From 33f85fadf63acc71eaab5a3da8b7ee16fff1a002 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 14 Oct 2024 19:16:23 +0200 Subject: [PATCH 001/170] add --- .../pipelines/custom_pipeline_builder.py | 1763 +++++++++++++++++ 1 file changed, 1763 insertions(+) create mode 100644 src/diffusers/pipelines/custom_pipeline_builder.py diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py new file mode 100644 index 000000000000..d3603d088932 --- /dev/null +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -0,0 +1,1763 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ..configuration_utils import ConfigMixin +from ..image_processor import VaeImageProcessor +from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ..models import ImageProjection +from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ..models.lora import adjust_lora_scale_text_encoder +from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..utils.torch_utils import randn_tensor +from .pipeline_loading_utils import _fetch_class_library_tuple +from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .stable_diffusion_xl import ( + StableDiffusionXLPipeline, + StableDiffusionXLPipelineOutput, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CustomPipeline(ConfigMixin): + """ + Base class for all custom pipelines built with CustomPipelineBuilder. + + [`CustomPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines. Unlike + [`DiffusionPipeline`], it's designed to be used exclusively with [`CustomPipelineBuilder`] and does not have a + `__call__` method. It cannot be called directly and must be run via the builder's run_pipeline method. + Additionally, it does not include methods for loading, downloading, or saving models, focusing only on + inference-related tasks, such as: + + - move all PyTorch modules to the device of your choice + - enable/disable the progress bar for the denoising iteration + + Usage: This class should not be instantiated directly. Instead, use CustomPipelineBuilder to create and configure a + CustomPipeline instance. + + Example: + builder = CustomPipelineBuilder("SDXL") builder.add_blocks([InputStep(), TextEncoderStep(), ...]) result = + builder.run_pipeline(prompt="A beautiful sunset") + + Class Attributes: + config_name (str): Filename for the configuration storing component class and module names. + + Note: This class is part of a modular pipeline system and is intended to be used in conjunction with + CustomPipelineBuilder for maximum flexibility and customization in diffusion pipelines. + """ + + config_name = "model_index.json" + model_cpu_offload_seq = None + hf_device_map = None + _exclude_from_cpu_offload = [] + + def __init__(self): + super().__init__() + self.register_to_config() + self.builder = None + + def __repr__(self): + if self.builder: + return repr(self.builder) + return "CustomPipeline (not fully initialized)" + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.register_modules + def register_modules(self, **kwargs): + for name, module in kwargs.items(): + # retrieve library + if module is None or isinstance(module, (tuple, list)) and module[0] is None: + register_dict = {name: (None, None)} + else: + library, class_name = _fetch_class_library_tuple(module) + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + @property + def components(self) -> Dict[str, Any]: + r""" + The `self.components` property returns all modules needed to initialize the pipeline, as defined by the + pipeline blocks. + + Returns (`dict`): + A dictionary containing all the components defined in the pipeline blocks. + """ + if not hasattr(self, "builder") or self.builder is None: + raise ValueError("Pipeline builder is not set. Cannot retrieve components.") + + components = {} + for block in self.builder.pipeline_blocks: + components.update(block.components) + + # Check if all items in config that are also in any block's components are included + for key in self.config.keys(): + if any(key in block.components for block in self.builder.pipeline_blocks): + if key not in components: + components[key] = getattr(self, key, None) + + return components + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + def __call__(self, *args, **kwargs): + raise NotImplementedError("__call__ is not implemented for CustomPipeline") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class CFGGuider: + """ + This class is used to guide the pipeline with CFG (Classifier-Free Guidance). + """ + + def prepare_inputs_for_cfg( + self, negative_cond_input: torch.Tensor, cond_input: torch.Tensor, do_classifier_free_guidance: bool + ) -> torch.Tensor: + if do_classifier_free_guidance: + return torch.cat([negative_cond_input, cond_input], dim=0) + else: + return cond_input + + def prepare_inputs(self, cfg_input_mapping: Dict[str, Any], do_classifier_free_guidance: bool) -> Dict[str, Any]: + prepared_inputs = {} + for cfg_input_name, (negative_cond_input, cond_input) in cfg_input_mapping.items(): + prepared_inputs[cfg_input_name] = self.prepare_inputs_for_cfg( + negative_cond_input, cond_input, do_classifier_free_guidance + ) + return prepared_inputs + + def apply_guidance( + self, + model_output: torch.Tensor, + guidance_scale: float, + do_classifier_free_guidance: bool, + guidance_rescale: float = 0.0, + ) -> torch.Tensor: + if not do_classifier_free_guidance: + return model_output + + noise_pred_uncond, noise_pred_text = model_output.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + return noise_pred + + +class SDXLCustomPipeline( + CustomPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, +): + def __init__(self): + super().__init__() + + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_latents = self.unet.config.in_channels + return num_channels_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + outputs: Dict[str, Any] = field(default_factory=dict) + + def add_input(self, key: str, value: Any): + self.inputs[key] = value + + def add_intermediate(self, key: str, value: Any): + self.intermediates[key] = value + + def add_output(self, value: Any): + self.outputs = value + + def get_input(self, key: str, default: Any = None) -> Any: + return self.inputs.get(key, default) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_output(self) -> Any: + return self.output + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} + + +class PipelineBlock: + components: Dict[str, Any] = {} + auxiliaries: Dict[str, Any] = {} + configs: Dict[str, Any] = {} + required_components: List[str] = [] + required_auxiliaries: List[str] = [] + inputs: List[Tuple[str, Any]] = [] # (input_name, default_value) + intermediates_inputs: List[str] = [] + intermediates_outputs: List[str] = [] + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + +class InputStep(PipelineBlock): + inputs = [ + ("prompt", None), + ("prompt_embeds", None), + ] + + intermediates_outputs = ["batch_size"] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + prompt = state.get_input("prompt") + prompt_embeds = state.get_input("prompt_embeds") + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + state.add_intermediate("batch_size", batch_size) + + return pipeline, state + + +class TextEncoderStep(PipelineBlock): + inputs = [ + ("prompt", None), + ("prompt_2", None), + ("negative_prompt", None), + ("negative_prompt_2", None), + ("cross_attention_kwargs", None), + ("prompt_embeds", None), + ("negative_prompt_embeds", None), + ("pooled_prompt_embeds", None), + ("negative_pooled_prompt_embeds", None), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("clip_skip", None), + ] + + intermediates_outputs = [ + "prompt_embeds", + "negative_prompt_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + ] + + def __init__( + self, + text_encoder: Optional[CLIPTextModel] = None, + text_encoder_2: Optional[CLIPTextModelWithProjection] = None, + tokenizer: Optional[CLIPTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + force_zeros_for_empty_prompt: bool = True, + ): + if text_encoder is not None: + self.components["text_encoder"] = text_encoder + if text_encoder_2 is not None: + self.components["text_encoder_2"] = text_encoder_2 + if tokenizer is not None: + self.components["tokenizer"] = tokenizer + if tokenizer_2 is not None: + self.components["tokenizer_2"] = tokenizer_2 + + self.configs["force_zeros_for_empty_prompt"] = force_zeros_for_empty_prompt + + @staticmethod + def check_inputs( + pipeline, + prompt, + prompt_2, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Get inputs + prompt = state.get_input("prompt") + prompt_2 = state.get_input("prompt_2") + negative_prompt = state.get_input("negative_prompt") + negative_prompt_2 = state.get_input("negative_prompt_2") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + prompt_embeds = state.get_input("prompt_embeds") + negative_prompt_embeds = state.get_input("negative_prompt_embeds") + pooled_prompt_embeds = state.get_input("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_input("negative_pooled_prompt_embeds") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + clip_skip = state.get_input("clip_skip") + + do_classifier_free_guidance = guidance_scale > 1.0 + device = pipeline._execution_device + + self.check_inputs( + pipeline, + prompt, + prompt_2, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # Add outputs + state.add_intermediate("prompt_embeds", prompt_embeds) + state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds) + state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds) + state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds) + return pipeline, state + + +class SetTimestepsStep(PipelineBlock): + inputs = [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ] + required_components = ["scheduler"] + intermediates_outputs = ["timesteps", "num_inference_steps"] + + def __init__(self, scheduler=None): + if scheduler is not None: + self.components["scheduler"] = scheduler + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + num_inference_steps = state.get_input("num_inference_steps") + timesteps = state.get_input("timesteps") + sigmas = state.get_input("sigmas") + denoising_end = state.get_input("denoising_end") + + device = pipeline._execution_device + + timesteps, num_inference_steps = retrieve_timesteps( + pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + pipeline.scheduler.config.num_train_timesteps + - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + state.add_intermediate("timesteps", timesteps) + state.add_intermediate("num_inference_steps", num_inference_steps) + + return pipeline, state + + +class Image2ImagePrepareLatentsStep(PipelineBlock): + intermediates_inputs = ["batch_size", "timesteps", "num_inference_steps"] + intermediates_outputs = ["image", "latents"] + + def __init__(self, vae=None): + if vae is not None: + self.components["vae"] = vae + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self->pipe + def prepare_latents( + pipe, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None: + latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None: + latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(pipe, "final_offload_hook") and pipe.final_offload_hook is not None: + pipe.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if pipe.vae.config.force_upcast: + image = image.float() + pipe.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator) + + if pipe.vae.config.force_upcast: + pipe.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * pipe.vae.config.scaling_factor / latents_std + else: + init_latents = pipe.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = pipe.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + image = state.get_input("image") + height = state.get_input("height") + width = state.get_input("width") + strength = state.get_input("strength") + denoising_start = state.get_input("denoising_start") + num_images_per_prompt = state.get_input("num_images_per_prompt") + generator = state.get_input("generator") + latents = state.get_input("latents") + + # get intermediates + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") + batch_size = state.get_intermediate("batch_size") + + device = pipeline._execution_device + dtype = pipeline.vae.dtype + + # 4. Prepare image and controlnet_conditioning_image + image = pipeline.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = pipeline.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + add_noise, + ) + + state.add_intermediate("image", image) + state.add_intermediate("latents", latents) + + return pipeline, state + + +class PrepareLatentsStep(PipelineBlock): + inputs = [ + ("height", None), + ("width", None), + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ] + required_components = ["scheduler"] + intermediates_inputs = ["batch_size"] + intermediates_outputs = ["latents"] + + def __init__(self, scheduler=None): + if scheduler is not None: + self.components["scheduler"] = scheduler + + @staticmethod + def check_inputs(pipeline, height, width): + if height % pipeline.vae_scale_factor != 0 or width % pipeline.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {height} and {width}." + ) + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + latents = state.get_input("latents") + num_images_per_prompt = state.get_input("num_images_per_prompt") + height = state.get_input("height") + width = state.get_input("width") + generator = state.get_input("generator") + + batch_size = state.get_intermediate("batch_size") + prompt_embeds = state.get_intermediate("prompt_embeds", None) + + dtype = prompt_embeds.dtype if prompt_embeds is not None else pipeline.dtype + device = pipeline._execution_device + + height = height or pipeline.default_sample_size * pipeline.vae_scale_factor + width = width or pipeline.default_sample_size * pipeline.vae_scale_factor + + self.check_inputs(pipeline, height, width) + + # 5. Prepare latent variables + + num_channels_latents = pipeline.num_channels_latents + latents = pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents, + ) + + state.add_intermediate("latents", latents) + + return pipeline, state + + +class PrepareAdditionalConditioningStep(PipelineBlock): + inputs = [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ] + intermediates_inputs = ["latents"] + intermediates_outputs = ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + required_components = ["unet"] + + def __init__(self, unet=None): + if unet is not None: + self.components["unet"] = unet + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + original_size = state.get_input("original_size") + target_size = state.get_input("target_size") + negative_original_size = state.get_input("negative_original_size") + negative_target_size = state.get_input("negative_target_size") + crops_coords_top_left = state.get_input("crops_coords_top_left") + negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + + latents = state.get_intermediate("latents") + batch_size = state.get_intermediate("batch_size") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: + text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim + else: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = pipeline._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = pipeline._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + # Optionally get Guidance Scale Embedding for LCM + timestep_cond = None + if ( + hasattr(pipeline, "unet") + and pipeline.unet is not None + and pipeline.unet.config.time_cond_proj_dim is not None + ): + guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = pipeline.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + state.add_intermediate("add_time_ids", add_time_ids) + state.add_intermediate("negative_add_time_ids", negative_add_time_ids) + state.add_intermediate("timestep_cond", timestep_cond) + return pipeline, state + + +class PrepareGuidance(PipelineBlock): + inputs = [ + ("guidance_scale", 5.0), + ] + intermediates_inputs = [ + "add_time_ids", + "negative_add_time_ids", + "prompt_embeds", + "negative_prompt_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + ] + intermediates_outputs = ["add_text_embeds", "add_time_ids", "prompt_embeds"] + + def __init__(self): + guider = CFGGuider() + self.auxiliaries["guider"] = guider + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + guidance_scale = state.get_input("guidance_scale") + + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + + do_classifier_free_guidance = guidance_scale > 1.0 + guider = pipeline.guider + + # Fetch all model inputs from pipeline_state + conditional_inputs = { + "prompt_embeds": (negative_prompt_embeds, prompt_embeds), + "add_time_ids": (negative_add_time_ids, add_time_ids), + "add_text_embeds": (negative_pooled_prompt_embeds, pooled_prompt_embeds), + } + + # Prepare inputs using the guider + prepared_conditional_inputs = guider.prepare_inputs(conditional_inputs, do_classifier_free_guidance) + + # Add prepared inputs back to the state + state.add_intermediate("add_text_embeds", prepared_conditional_inputs["add_text_embeds"]) + state.add_intermediate("add_time_ids", prepared_conditional_inputs["add_time_ids"]) + state.add_intermediate("prompt_embeds", prepared_conditional_inputs["prompt_embeds"]) + + return pipeline, state + + +class DenoiseStep(PipelineBlock): + inputs = [ + ("guidance_scale", 5.0), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ] + intermediates_inputs = [ + "latents", + "timesteps", + "num_inference_steps", + "add_text_embeds", + "add_time_ids", + "timestep_cond", + "prompt_embeds", + ] + intermediates_outputs = ["latents"] + required_components = ["unet"] + + def __init__(self, unet=None): + if unet is not None: + self.components["unet"] = unet + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + guidance_scale = state.get_input("guidance_scale") + guidance_rescale = state.get_input("guidance_rescale") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + generator = state.get_input("generator") + eta = state.get_input("eta") + + latents = state.get_intermediate("latents") + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") + + add_text_embeds = state.get_intermediate("add_text_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + timestep_cond = state.get_intermediate("timestep_cond") + prompt_embeds = state.get_intermediate("prompt_embeds") + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = pipeline.guider.prepare_inputs_for_cfg( + latents, latents, do_classifier_free_guidance + ) + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = pipeline.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + # perform guidance + noise_pred = pipeline.guider.apply_guidance( + noise_pred, guidance_scale, do_classifier_free_guidance, guidance_rescale + ) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + state.add_intermediate("latents", latents) + + return pipeline, state + + +class DecodeLatentsStep(PipelineBlock): + inputs = [ + ("output_type", "pil"), + ("return_dict", True), + ] + intermediates_inputs = ["latents"] + + def __init__(self, vae=None, vae_scale_factor=8): + if vae is not None: + self.components["vae"] = vae + image_processor = VaeImageProcessor(vae_scale_factor=8) + self.auxiliaries["image_processor"] = image_processor + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + output_type = state.get_input("output_type") + return_dict = state.get_input("return_dict") + + latents = state.get_intermediate("latents") + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast + + if needs_upcasting: + pipeline.upcast_vae() + latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != pipeline.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + pipeline.vae = pipeline.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None + ) + has_latents_std = ( + hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None + ) + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / pipeline.vae.config.scaling_factor + latents_mean + else: + latents = latents / pipeline.vae.config.scaling_factor + + image = pipeline.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + pipeline.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if hasattr(pipeline, "watermark") and pipeline.watermark is not None: + image = pipeline.watermark.apply_watermark(image) + + image = pipeline.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + output = (image,) + else: + output = StableDiffusionXLPipelineOutput(images=image) + + state.add_intermediate("images", image) + state.add_output(output) + + return pipeline, state + + +class PipelineBlockType(Enum): + InputStep = 1 + TextEncoderStep = 2 + SetTimestepsStep = 3 + PrepareLatentsStep = 4 + PrepareAdditionalConditioningStep = 5 + PrepareGuidance = 6 + DenoiseStep = 7 + DecodeLatentsStep = 8 + + +PIPELINE_BLOCKS = { + StableDiffusionXLPipeline: [ + PipelineBlockType.InputStep, + PipelineBlockType.TextEncoderStep, + PipelineBlockType.SetTimestepsStep, + PipelineBlockType.PrepareLatentsStep, + PipelineBlockType.PrepareAdditionalConditioningStep, + PipelineBlockType.PrepareGuidance, + PipelineBlockType.DenoiseStep, + PipelineBlockType.DecodeLatentsStep, + ], +} + + +class CustomPipelineBuilder: + def __init__(self, pipeline_class: str): + if pipeline_class == "SDXL": + self.pipeline = SDXLCustomPipeline() + else: + raise ValueError(f"Pipeline class {pipeline_class} not supported") + self.pipeline_blocks = [] + self.pipeline.builder = self + + def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]): + if not isinstance(pipeline_blocks, list): + pipeline_blocks = [pipeline_blocks] + + for block in pipeline_blocks: + self.pipeline_blocks.append(block) + self.pipeline.register_modules(**block.components) + self.pipeline.register_to_config(**block.configs) + # Add auxiliaries as attributes to the pipeline + for key, value in block.auxiliaries.items(): + setattr(self.pipeline, key, value) + + for required_component in block.required_components: + if ( + not hasattr(self.pipeline, required_component) + or getattr(self.pipeline, required_component) is None + ): + raise ValueError( + f"Cannot add block {block.__class__.__name__}: Required component {required_component} not found in pipeline" + ) + + for required_auxiliary in block.required_auxiliaries: + if ( + not hasattr(self.pipeline, required_auxiliary) + or getattr(self.pipeline, required_auxiliary) is None + ): + raise ValueError( + f"Cannot add block {block.__class__.__name__}: Required auxiliary {required_auxiliary} not found in pipeline" + ) + + def run_pipeline(self, return_pipeline_state=False, **kwargs): + state = PipelineState() + pipeline = self.pipeline + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # Add inputs to state, using defaults if not provided + for name, default in default_params.items(): + if name in input_params: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, default) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + + # Run the pipeline + with torch.no_grad(): + for block in self.pipeline_blocks: + pipeline, state = block(pipeline, state) + + if return_pipeline_state: + return state + else: + return state.outputs + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for block in self.pipeline_blocks: + for name, default in block.inputs: + if name not in params: + params[name] = default + return params + + def __repr__(self): + output = "CustomPipeline Configuration:\n" + output += "==============================\n\n" + + # List the blocks used to build the pipeline + output += "Pipeline Blocks:\n" + output += "----------------\n" + for i, block in enumerate(self.pipeline_blocks, 1): + output += f"{i}. {block.__class__.__name__}\n" + output += "\n" + + # List the components registered in the pipeline + output += "Registered Components:\n" + output += "----------------------\n" + for name, component in self.pipeline.components.items(): + output += f"{name}: {type(component).__name__}\n" + output += "\n" + + # List the default call parameters + output += "Default Call Parameters:\n" + output += "------------------------\n" + params = self.default_call_parameters + for name, default in params.items(): + output += f"{name}: {default!r}\n" + + output += "\nNote: These are the default values. Actual values may be different when running the pipeline." + return output From 52a7f1cb971cd402721f8e11a45f6dbbd12fe42a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Oct 2024 09:04:32 +0200 Subject: [PATCH 002/170] add dataflow info for each block in builder _repr_ --- .../pipelines/custom_pipeline_builder.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index d3603d088932..6c510a783ecd 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -775,6 +775,94 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -1743,6 +1831,22 @@ def __repr__(self): output += "----------------\n" for i, block in enumerate(self.pipeline_blocks, 1): output += f"{i}. {block.__class__.__name__}\n" + + intermediates_str = "" + if hasattr(block, 'intermediates_inputs'): + intermediates_str += f"{', '.join(block.intermediates_inputs)}" + + if hasattr(block, 'intermediates_outputs'): + if intermediates_str: + intermediates_str += " -> " + else: + intermediates_str += "-> " + intermediates_str += f"{', '.join(block.intermediates_outputs)}" + + if intermediates_str: + output += f" {intermediates_str}\n" + + output += "\n" output += "\n" # List the components registered in the pipeline From e8d0980f9fe22856c50da1749baf520b0717fff6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 16 Oct 2024 20:56:39 +0200 Subject: [PATCH 003/170] add img2img support - output does not match with non-modular pipeline completely yet (look into later) --- .../pipelines/custom_pipeline_builder.py | 316 ++++++++++++------ 1 file changed, 210 insertions(+), 106 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index 6c510a783ecd..e17b39c8ef65 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -375,6 +375,58 @@ def _get_add_time_ids( add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids_img2img( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype @@ -862,7 +914,6 @@ def prepare_latents_img2img( return latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -1209,141 +1260,105 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class Image2ImageSetTimestepsStep(PipelineBlock): + inputs = [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ("strength", 0.3), + ("denoising_start", None), + ("num_images_per_prompt", 1), + ] + required_components = ["scheduler"] + intermediates_outputs = ["timesteps", "num_inference_steps", "latent_timestep"] -class Image2ImagePrepareLatentsStep(PipelineBlock): - intermediates_inputs = ["batch_size", "timesteps", "num_inference_steps"] - intermediates_outputs = ["image", "latents"] - - def __init__(self, vae=None): - if vae is not None: - self.components["vae"] = vae - - @staticmethod - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self->pipe - def prepare_latents( - pipe, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) + def __init__(self, scheduler=None): + if scheduler is not None: + self.components["scheduler"] = scheduler - latents_mean = latents_std = None - if hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None: - latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None: - latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 4, 1, 1) + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + num_inference_steps = state.get_input("num_inference_steps") + timesteps = state.get_input("timesteps") + sigmas = state.get_input("sigmas") + denoising_end = state.get_input("denoising_end") + strength = state.get_input("strength") + denoising_start = state.get_input("denoising_start") + num_images_per_prompt = state.get_input("num_images_per_prompt") - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(pipe, "final_offload_hook") and pipe.final_offload_hook is not None: - pipe.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + batch_size = state.get_intermediate("batch_size") - image = image.to(device=device, dtype=dtype) + device = pipeline._execution_device - batch_size = batch_size * num_images_per_prompt + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 - if image.shape[1] == 4: - init_latents = image + timesteps, num_inference_steps = retrieve_timesteps( + pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + ) - else: - # make sure the VAE is in float32 mode, as it overflows in float16 - if pipe.vae.config.force_upcast: - image = image.float() - pipe.vae.to(dtype=torch.float32) + timesteps, num_inference_steps = pipeline.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + pipeline.scheduler.config.num_train_timesteps + - (denoising_end * pipeline.scheduler.config.num_train_timesteps) ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator) - - if pipe.vae.config.force_upcast: - pipe.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * pipe.vae.config.scaling_factor / latents_std - else: - init_latents = pipe.vae.config.scaling_factor * init_latents + state.add_intermediate("timesteps", timesteps) + state.add_intermediate("num_inference_steps", num_inference_steps) + state.add_intermediate("latent_timestep", latent_timestep) - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) + return pipeline, state - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = pipe.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents +class Image2ImagePrepareLatentsStep(PipelineBlock): + inputs = [ + ("image", None), + ("num_images_per_prompt", 1), + ("generator", None), + ("latents", None), + ] + intermediates_inputs = ["batch_size", "timesteps", "num_inference_steps"] + intermediates_outputs = ["latents", "timesteps", "num_inference_steps"] - return latents + def __init__(self, vae=None, vae_scale_factor=8): + if vae is not None: + self.components["vae"] = vae + self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: image = state.get_input("image") - height = state.get_input("height") - width = state.get_input("width") - strength = state.get_input("strength") - denoising_start = state.get_input("denoising_start") num_images_per_prompt = state.get_input("num_images_per_prompt") generator = state.get_input("generator") latents = state.get_input("latents") - + denoising_start = state.get_input("denoising_start") # get intermediates - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") batch_size = state.get_intermediate("batch_size") + latent_timestep = state.get_intermediate("latent_timestep") device = pipeline._execution_device dtype = pipeline.vae.dtype - # 4. Prepare image and controlnet_conditioning_image - image = pipeline.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = pipeline.image_processor.preprocess(image) - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = pipeline.get_timesteps( - num_inference_steps, - strength, - device, - denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + add_noise = True if denoising_start is None else False - add_noise = True if self.denoising_start is None else False - - # 6. Prepare latent variables if latents is None: - latents = self.prepare_latents( + latents = pipeline.prepare_latents_img2img( image, latent_timestep, batch_size, @@ -1354,7 +1369,6 @@ def denoising_value_valid(dnv): add_noise, ) - state.add_intermediate("image", image) state.add_intermediate("latents", latents) return pipeline, state @@ -1507,6 +1521,96 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin state.add_intermediate("timestep_cond", timestep_cond) return pipeline, state +class Image2ImagePrepareAdditionalConditioningStep(PipelineBlock): + inputs = [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("aesthetic_score", 6.0), + ("negative_aesthetic_score", 2.0), + ] + intermediates_inputs = ["latents"] + intermediates_outputs = ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + required_components = ["unet"] + + def __init__(self, unet=None, requires_aesthetics_score=False): + if unet is not None: + self.components["unet"] = unet + if requires_aesthetics_score is not None: + self.configs["requires_aesthetics_score"] = requires_aesthetics_score + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + original_size = state.get_input("original_size") + target_size = state.get_input("target_size") + negative_original_size = state.get_input("negative_original_size") + negative_target_size = state.get_input("negative_target_size") + crops_coords_top_left = state.get_input("crops_coords_top_left") + negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + aesthetic_score = state.get_input("aesthetic_score") + negative_aesthetic_score = state.get_input("negative_aesthetic_score") + + latents = state.get_intermediate("latents") + batch_size = state.get_intermediate("batch_size") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: + text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim + else: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + # Optionally get Guidance Scale Embedding for LCM + timestep_cond = None + if ( + hasattr(pipeline, "unet") + and pipeline.unet is not None + and pipeline.unet.config.time_cond_proj_dim is not None + ): + guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = pipeline.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + state.add_intermediate("add_time_ids", add_time_ids) + state.add_intermediate("negative_add_time_ids", negative_add_time_ids) + state.add_intermediate("timestep_cond", timestep_cond) + return pipeline, state class PrepareGuidance(PipelineBlock): inputs = [ From ad3f9a26c0055ae53205183868b737f199153f93 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 17 Oct 2024 05:47:15 +0200 Subject: [PATCH 004/170] update img2img, result match --- .../pipelines/custom_pipeline_builder.py | 493 ++++++++++++------ 1 file changed, 327 insertions(+), 166 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index e17b39c8ef65..f1ba2560b87d 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -1015,26 +1015,81 @@ def to_dict(self) -> Dict[str, Any]: class PipelineBlock: - components: Dict[str, Any] = {} - auxiliaries: Dict[str, Any] = {} - configs: Dict[str, Any] = {} - required_components: List[str] = [] - required_auxiliaries: List[str] = [] - inputs: List[Tuple[str, Any]] = [] # (input_name, default_value) - intermediates_inputs: List[str] = [] - intermediates_outputs: List[str] = [] + + @property + def optional_components(self) -> List[str]: + return [] + + @property + def required_components(self) -> List[str]: + return [] + + @property + def required_auxiliaries(self) -> List[str]: + return [] + + @property + def inputs(self) -> Tuple[Tuple[str, Any], ...]: + # (input_name, default_value) + return () + + @property + def intermediates_inputs(self) -> List[str]: + return [] + + @property + def intermediates_outputs(self) -> List[str]: + return [] + + def __init__(self, **kwargs): + self.components: Dict[str, Any] = {} + self.auxiliaries: Dict[str, Any] = {} + self.configs: Dict[str, Any] = {} + + # Process kwargs + for key, value in kwargs.items(): + if key in self.required_components or key in self.optional_components: + self.components[key] = value + elif key in self.required_auxiliaries: + self.auxiliaries[key] = value + else: + self.configs[key] = value + def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") + def __repr__(self): + class_name = self.__class__.__name__ + components = ", ".join(f"{k}={type(v).__name__}" for k, v in self.components.items()) + auxiliaries = ", ".join(f"{k}={type(v).__name__}" for k, v in self.auxiliaries.items()) + configs = ", ".join(f"{k}={v}" for k, v in self.configs.items()) + inputs = ", ".join(f"{name}={default}" for name, default in self.inputs) + intermediates_inputs = ", ".join(self.intermediates_inputs) + intermediates_outputs = ", ".join(self.intermediates_outputs) + + return (f"{class_name}(\n" + f" components: {components}\n" + f" auxiliaries: {auxiliaries}\n" + f" configs: {configs}\n" + f" inputs: {inputs}\n" + f" intermediates_inputs: {intermediates_inputs}\n" + f" intermediates_outputs: {intermediates_outputs}\n" + f")") + class InputStep(PipelineBlock): - inputs = [ - ("prompt", None), - ("prompt_embeds", None), - ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("prompt", None), + ("prompt_embeds", None), + ] - intermediates_outputs = ["batch_size"] + @property + def intermediates_outputs(self) -> List[str]: + return ["batch_size"] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1054,27 +1109,35 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class TextEncoderStep(PipelineBlock): - inputs = [ - ("prompt", None), - ("prompt_2", None), - ("negative_prompt", None), - ("negative_prompt_2", None), - ("cross_attention_kwargs", None), - ("prompt_embeds", None), - ("negative_prompt_embeds", None), - ("pooled_prompt_embeds", None), - ("negative_pooled_prompt_embeds", None), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("clip_skip", None), - ] - - intermediates_outputs = [ - "prompt_embeds", - "negative_prompt_embeds", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - ] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("prompt", None), + ("prompt_2", None), + ("negative_prompt", None), + ("negative_prompt_2", None), + ("cross_attention_kwargs", None), + ("prompt_embeds", None), + ("negative_prompt_embeds", None), + ("pooled_prompt_embeds", None), + ("negative_pooled_prompt_embeds", None), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("clip_skip", None), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + "prompt_embeds", + "negative_prompt_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + ] + + @property + def optional_components(self) -> List[str]: + return ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] def __init__( self, @@ -1084,16 +1147,13 @@ def __init__( tokenizer_2: Optional[CLIPTokenizer] = None, force_zeros_for_empty_prompt: bool = True, ): - if text_encoder is not None: - self.components["text_encoder"] = text_encoder - if text_encoder_2 is not None: - self.components["text_encoder_2"] = text_encoder_2 - if tokenizer is not None: - self.components["tokenizer"] = tokenizer - if tokenizer_2 is not None: - self.components["tokenizer_2"] = tokenizer_2 - - self.configs["force_zeros_for_empty_prompt"] = force_zeros_for_empty_prompt + super().__init__( + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, + ) @staticmethod def check_inputs( @@ -1219,18 +1279,25 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class SetTimestepsStep(PipelineBlock): - inputs = [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), - ] - required_components = ["scheduler"] - intermediates_outputs = ["timesteps", "num_inference_steps"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ] + + @property + def required_components(self) -> List[str]: + return ["scheduler"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["timesteps", "num_inference_steps"] def __init__(self, scheduler=None): - if scheduler is not None: - self.components["scheduler"] = scheduler + super().__init__(scheduler=scheduler) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1260,22 +1327,30 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state + class Image2ImageSetTimestepsStep(PipelineBlock): - inputs = [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), - ("strength", 0.3), - ("denoising_start", None), - ("num_images_per_prompt", 1), - ] - required_components = ["scheduler"] - intermediates_outputs = ["timesteps", "num_inference_steps", "latent_timestep"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ("strength", 0.3), + ("denoising_start", None), + ("num_images_per_prompt", 1), + ] + + @property + def required_components(self) -> List[str]: + return ["scheduler"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["timesteps", "num_inference_steps", "latent_timestep"] def __init__(self, scheduler=None): - if scheduler is not None: - self.components["scheduler"] = scheduler + super().__init__(scheduler=scheduler) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1324,19 +1399,35 @@ def denoising_value_valid(dnv): class Image2ImagePrepareLatentsStep(PipelineBlock): - inputs = [ - ("image", None), - ("num_images_per_prompt", 1), - ("generator", None), - ("latents", None), - ] - intermediates_inputs = ["batch_size", "timesteps", "num_inference_steps"] - intermediates_outputs = ["latents", "timesteps", "num_inference_steps"] + + @property + def required_auxiliaries(self) -> List[str]: + return ["image_processor"] + + @property + def required_components(self) -> List[str]: + return ["vae"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("image", None), + ("num_images_per_prompt", 1), + ("generator", None), + ("latents", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "timesteps", "num_inference_steps"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents", "timesteps", "num_inference_steps"] def __init__(self, vae=None, vae_scale_factor=8): - if vae is not None: - self.components["vae"] = vae - self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1375,20 +1466,31 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class PrepareLatentsStep(PipelineBlock): - inputs = [ - ("height", None), - ("width", None), - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), - ] - required_components = ["scheduler"] - intermediates_inputs = ["batch_size"] - intermediates_outputs = ["latents"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("height", None), + ("width", None), + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ] + + @property + def required_components(self) -> List[str]: + return ["scheduler"] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] def __init__(self, scheduler=None): - if scheduler is not None: - self.components["scheduler"] = scheduler + super().__init__(scheduler=scheduler) @staticmethod def check_inputs(pipeline, height, width): @@ -1436,23 +1538,33 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class PrepareAdditionalConditioningStep(PipelineBlock): - inputs = [ - ("original_size", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ] - intermediates_inputs = ["latents"] - intermediates_outputs = ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - required_components = ["unet"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["latents"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + + @property + def required_components(self) -> List[str]: + return ["unet"] def __init__(self, unet=None): - if unet is not None: - self.components["unet"] = unet + super().__init__(unet=unet) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1521,28 +1633,37 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin state.add_intermediate("timestep_cond", timestep_cond) return pipeline, state + class Image2ImagePrepareAdditionalConditioningStep(PipelineBlock): - inputs = [ - ("original_size", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("aesthetic_score", 6.0), - ("negative_aesthetic_score", 2.0), - ] - intermediates_inputs = ["latents"] - intermediates_outputs = ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - required_components = ["unet"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("aesthetic_score", 6.0), + ("negative_aesthetic_score", 2.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["latents"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + + @property + def required_components(self) -> List[str]: + return ["unet"] def __init__(self, unet=None, requires_aesthetics_score=False): - if unet is not None: - self.components["unet"] = unet - if requires_aesthetics_score is not None: - self.configs["requires_aesthetics_score"] = requires_aesthetics_score + super().__init__(unet=unet, requires_aesthetics_score=requires_aesthetics_score) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1612,23 +1733,36 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin state.add_intermediate("timestep_cond", timestep_cond) return pipeline, state + class PrepareGuidance(PipelineBlock): - inputs = [ - ("guidance_scale", 5.0), - ] - intermediates_inputs = [ - "add_time_ids", - "negative_add_time_ids", - "prompt_embeds", - "negative_prompt_embeds", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - ] - intermediates_outputs = ["add_text_embeds", "add_time_ids", "prompt_embeds"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("guidance_scale", 5.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + "add_time_ids", + "negative_add_time_ids", + "prompt_embeds", + "negative_prompt_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + ] + + @property + def intermediates_outputs(self) -> List[str]: + return ["add_text_embeds", "add_time_ids", "prompt_embeds"] + + @property + def required_auxiliaries(self) -> List[str]: + return ["guider"] def __init__(self): guider = CFGGuider() - self.auxiliaries["guider"] = guider + super().__init__(guider=guider) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1663,28 +1797,38 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class DenoiseStep(PipelineBlock): - inputs = [ - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ] - intermediates_inputs = [ - "latents", - "timesteps", - "num_inference_steps", - "add_text_embeds", - "add_time_ids", - "timestep_cond", - "prompt_embeds", - ] - intermediates_outputs = ["latents"] - required_components = ["unet"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("guidance_scale", 5.0), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + "latents", + "timesteps", + "num_inference_steps", + "add_text_embeds", + "add_time_ids", + "timestep_cond", + "prompt_embeds", + ] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + + @property + def required_components(self) -> List[str]: + return ["unet"] def __init__(self, unet=None): - if unet is not None: - self.components["unet"] = unet + super().__init__(unet=unet) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1748,17 +1892,28 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class DecodeLatentsStep(PipelineBlock): - inputs = [ - ("output_type", "pil"), - ("return_dict", True), - ] - intermediates_inputs = ["latents"] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("output_type", "pil"), + ("return_dict", True), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["latents"] + + @property + def optional_components(self) -> List[str]: + return ["vae"] + + @property + def required_auxiliaries(self) -> List[str]: + return ["image_processor"] def __init__(self, vae=None, vae_scale_factor=8): - if vae is not None: - self.components["vae"] = vae image_processor = VaeImageProcessor(vae_scale_factor=8) - self.auxiliaries["image_processor"] = image_processor + super().__init__(vae=vae, vae_scale_factor=vae_scale_factor, image_processor=image_processor) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1863,7 +2018,12 @@ def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]) for block in pipeline_blocks: self.pipeline_blocks.append(block) - self.pipeline.register_modules(**block.components) + # filter out components that already exist in the pipeline + components_to_register = {} + for k, v in block.components.items(): + if not hasattr(self.pipeline, k) or v is not None: + components_to_register[k] = v + self.pipeline.register_modules(**components_to_register) self.pipeline.register_to_config(**block.configs) # Add auxiliaries as attributes to the pipeline for key, value in block.auxiliaries.items(): @@ -1917,6 +2077,7 @@ def run_pipeline(self, return_pipeline_state=False, **kwargs): else: return state.outputs + @property def default_call_parameters(self) -> Dict[str, Any]: params = {} From ddea157979e0912050634fcb2b26dd5b8fd7fbe1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 17 Oct 2024 20:02:36 +0200 Subject: [PATCH 005/170] add from_pipe + run_blocks --- .../pipelines/custom_pipeline_builder.py | 320 +++++++++++------- 1 file changed, 204 insertions(+), 116 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index f1ba2560b87d..fa4a6bb2be8e 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -913,7 +913,7 @@ def prepare_latents_img2img( latents = init_latents return latents - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -998,8 +998,8 @@ def add_input(self, key: str, value: Any): def add_intermediate(self, key: str, value: Any): self.intermediates[key] = value - def add_output(self, value: Any): - self.outputs = value + def add_output(self, key: str, value: Any): + self.outputs[key] = value def get_input(self, key: str, default: Any = None) -> Any: return self.inputs.get(key, default) @@ -1007,26 +1007,38 @@ def get_input(self, key: str, default: Any = None) -> Any: def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) - def get_output(self) -> Any: - return self.output + def get_output(self, key: str, default: Any = None) -> Any: + return self.outputs.get(key, default) def to_dict(self) -> Dict[str, Any]: return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(\n dtype={v.dtype}, shape={v.shape}\n {v})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(\n dtype={v[0].dtype}, shape={v[0].shape}\n {v[0]}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" outputs={{\n{outputs}\n }}\n" + f")" + ) -class PipelineBlock: - - @property - def optional_components(self) -> List[str]: - return [] - - @property - def required_components(self) -> List[str]: - return [] - @property - def required_auxiliaries(self) -> List[str]: - return [] +class PipelineBlock: + optional_components = [] + required_components = [] + required_auxiliaries = [] @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -1055,6 +1067,45 @@ def __init__(self, **kwargs): else: self.configs[key] = value + @classmethod + def from_pipe(cls, pipe: DiffusionPipeline): + """ + Create a PipelineBlock instance from a diffusion pipeline object. + + Args: + pipe: A `[DiffusionPipeline]` object. + + Returns: + PipelineBlock: An instance initialized with the pipeline's components and configurations. + """ + kwargs = {} + + # Add components + for component_name, component in pipe.components.items(): + if component_name in cls.required_components or component_name in cls.optional_components: + kwargs[component_name] = component + + # Add config items that are in the __init__ signature + init_params = inspect.signature(cls.__init__).parameters + for config_name in pipe.config.keys(): + if config_name in init_params and config_name not in kwargs: + kwargs[config_name] = pipe.config[config_name] + # Check for required auxiliaries + for aux_name in cls.required_auxiliaries: + if hasattr(pipe, aux_name): + kwargs[aux_name] = getattr(pipe, aux_name) + + # Add any remaining relevant attributes + for attr_name in dir(pipe): + if ( + not attr_name.startswith("_") + and attr_name not in kwargs + and attr_name not in ["components", "config"] + and attr_name in init_params + ): + kwargs[attr_name] = getattr(pipe, attr_name) + + return cls(**kwargs) def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -1068,18 +1119,19 @@ def __repr__(self): intermediates_inputs = ", ".join(self.intermediates_inputs) intermediates_outputs = ", ".join(self.intermediates_outputs) - return (f"{class_name}(\n" - f" components: {components}\n" - f" auxiliaries: {auxiliaries}\n" - f" configs: {configs}\n" - f" inputs: {inputs}\n" - f" intermediates_inputs: {intermediates_inputs}\n" - f" intermediates_outputs: {intermediates_outputs}\n" - f")") + return ( + f"{class_name}(\n" + f" components: {components}\n" + f" auxiliaries: {auxiliaries}\n" + f" configs: {configs}\n" + f" inputs: {inputs}\n" + f" intermediates_inputs: {intermediates_inputs}\n" + f" intermediates_outputs: {intermediates_outputs}\n" + f")" + ) class InputStep(PipelineBlock): - @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1109,6 +1161,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class TextEncoderStep(PipelineBlock): + optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1134,10 +1188,6 @@ def intermediates_outputs(self) -> List[str]: "pooled_prompt_embeds", "negative_pooled_prompt_embeds", ] - - @property - def optional_components(self) -> List[str]: - return ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] def __init__( self, @@ -1279,6 +1329,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class SetTimestepsStep(PipelineBlock): + required_components = ["scheduler"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1288,10 +1340,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("denoising_end", None), ] - @property - def required_components(self) -> List[str]: - return ["scheduler"] - @property def intermediates_outputs(self) -> List[str]: return ["timesteps", "num_inference_steps"] @@ -1329,6 +1377,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class Image2ImageSetTimestepsStep(PipelineBlock): + required_components = ["scheduler"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1341,10 +1391,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("num_images_per_prompt", 1), ] - @property - def required_components(self) -> List[str]: - return ["scheduler"] - @property def intermediates_outputs(self) -> List[str]: return ["timesteps", "num_inference_steps", "latent_timestep"] @@ -1399,15 +1445,9 @@ def denoising_value_valid(dnv): class Image2ImagePrepareLatentsStep(PipelineBlock): - - @property - def required_auxiliaries(self) -> List[str]: - return ["image_processor"] - - @property - def required_components(self) -> List[str]: - return ["vae"] - + required_components = ["vae"] + required_auxiliaries = ["image_processor"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1415,6 +1455,8 @@ def inputs(self) -> List[Tuple[str, Any]]: ("num_images_per_prompt", 1), ("generator", None), ("latents", None), + ("device", None), + ("dtype", None), ] @property @@ -1425,8 +1467,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents", "timesteps", "num_inference_steps"] - def __init__(self, vae=None, vae_scale_factor=8): - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): + if image_processor is None: + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) @torch.no_grad() @@ -1436,16 +1479,18 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin generator = state.get_input("generator") latents = state.get_input("latents") denoising_start = state.get_input("denoising_start") + device = state.get_input("device") + dtype = state.get_input("dtype") + # get intermediates batch_size = state.get_intermediate("batch_size") latent_timestep = state.get_intermediate("latent_timestep") - device = pipeline._execution_device - dtype = pipeline.vae.dtype + device = pipeline._execution_device if device is None else device + dtype = pipeline.vae.dtype if dtype is None else dtype image = pipeline.image_processor.preprocess(image) - add_noise = True if denoising_start is None else False if latents is None: @@ -1466,7 +1511,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class PrepareLatentsStep(PipelineBlock): - + required_components = ["scheduler"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1475,16 +1521,14 @@ def inputs(self) -> List[Tuple[str, Any]]: ("generator", None), ("latents", None), ("num_images_per_prompt", 1), + ("device", None), + ("dtype", None), ] - - @property - def required_components(self) -> List[str]: - return ["scheduler"] - + @property def intermediates_inputs(self) -> List[str]: return ["batch_size"] - + @property def intermediates_outputs(self) -> List[str]: return ["latents"] @@ -1506,12 +1550,19 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin height = state.get_input("height") width = state.get_input("width") generator = state.get_input("generator") + device = state.get_input("device") + dtype = state.get_input("dtype") batch_size = state.get_intermediate("batch_size") prompt_embeds = state.get_intermediate("prompt_embeds", None) - dtype = prompt_embeds.dtype if prompt_embeds is not None else pipeline.dtype - device = pipeline._execution_device + if dtype is None and prompt_embeds is not None: + dtype = prompt_embeds.dtype + elif dtype is None: + dtype = pipeline.vae.dtype + + if device is None: + device = pipeline._execution_device height = height or pipeline.default_sample_size * pipeline.vae_scale_factor width = width or pipeline.default_sample_size * pipeline.vae_scale_factor @@ -1538,6 +1589,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class PrepareAdditionalConditioningStep(PipelineBlock): + required_components = ["unet"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1550,18 +1603,14 @@ def inputs(self) -> List[Tuple[str, Any]]: ("num_images_per_prompt", 1), ("guidance_scale", 5.0), ] - + @property def intermediates_inputs(self) -> List[str]: - return ["latents"] - + return ["latents", "batch_size", "pooled_prompt_embeds"] + @property def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - - @property - def required_components(self) -> List[str]: - return ["unet"] def __init__(self, unet=None): super().__init__(unet=unet) @@ -1635,10 +1684,12 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class Image2ImagePrepareAdditionalConditioningStep(PipelineBlock): + required_components = ["unet"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("original_size", None), + ("original_sizife", None), ("target_size", None), ("negative_original_size", None), ("negative_target_size", None), @@ -1649,18 +1700,14 @@ def inputs(self) -> List[Tuple[str, Any]]: ("aesthetic_score", 6.0), ("negative_aesthetic_score", 2.0), ] - + @property def intermediates_inputs(self) -> List[str]: return ["latents"] - + @property def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - - @property - def required_components(self) -> List[str]: - return ["unet"] def __init__(self, unet=None, requires_aesthetics_score=False): super().__init__(unet=unet, requires_aesthetics_score=requires_aesthetics_score) @@ -1735,12 +1782,14 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class PrepareGuidance(PipelineBlock): + required_auxiliaries = ["guider"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ ("guidance_scale", 5.0), ] - + @property def intermediates_inputs(self) -> List[str]: return [ @@ -1751,17 +1800,14 @@ def intermediates_inputs(self) -> List[str]: "pooled_prompt_embeds", "negative_pooled_prompt_embeds", ] - + @property def intermediates_outputs(self) -> List[str]: return ["add_text_embeds", "add_time_ids", "prompt_embeds"] - - @property - def required_auxiliaries(self) -> List[str]: - return ["guider"] - def __init__(self): - guider = CFGGuider() + def __init__(self, guider=None): + if guider is None: + guider = CFGGuider() super().__init__(guider=guider) @torch.no_grad() @@ -1797,6 +1843,9 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class DenoiseStep(PipelineBlock): + required_components = ["unet", "scheduler"] + required_auxiliaries = ["guider"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -1806,7 +1855,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ("generator", None), ("eta", 0.0), ] - + @property def intermediates_inputs(self) -> List[str]: return [ @@ -1817,18 +1866,16 @@ def intermediates_inputs(self) -> List[str]: "add_time_ids", "timestep_cond", "prompt_embeds", - ] - + ] + @property def intermediates_outputs(self) -> List[str]: return ["latents"] - - @property - def required_components(self) -> List[str]: - return ["unet"] - def __init__(self, unet=None): - super().__init__(unet=unet) + def __init__(self, unet=None, scheduler=None, guider=None): + if guider is None: + guider = CFGGuider() + super().__init__(unet=unet, scheduler=scheduler, guider=guider) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1892,28 +1939,28 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class DecodeLatentsStep(PipelineBlock): + optional_components = ["vae"] + required_auxiliaries = ["image_processor"] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ ("output_type", "pil"), ("return_dict", True), ] - + @property def intermediates_inputs(self) -> List[str]: return ["latents"] - - @property - def optional_components(self) -> List[str]: - return ["vae"] - + @property - def required_auxiliaries(self) -> List[str]: - return ["image_processor"] + def intermediates_outputs(self) -> List[str]: + return ["images"] - def __init__(self, vae=None, vae_scale_factor=8): - image_processor = VaeImageProcessor(vae_scale_factor=8) - super().__init__(vae=vae, vae_scale_factor=vae_scale_factor, image_processor=image_processor) + def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): + if image_processor is None: + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1973,7 +2020,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: output = StableDiffusionXLPipelineOutput(images=image) state.add_intermediate("images", image) - state.add_output(output) + state.add_output("images", output) return pipeline, state @@ -2047,8 +2094,13 @@ def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]) f"Cannot add block {block.__class__.__name__}: Required auxiliary {required_auxiliary} not found in pipeline" ) - def run_pipeline(self, return_pipeline_state=False, **kwargs): - state = PipelineState() + def run_blocks(self, state: PipelineState = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + pipeline = self.pipeline # Make a copy of the input kwargs @@ -2056,6 +2108,11 @@ def run_pipeline(self, return_pipeline_state=False, **kwargs): default_params = self.default_call_parameters + # user can pass the intermediate of the first block + for name in self.pipeline_blocks[0].intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + # Add inputs to state, using defaults if not provided for name, default in default_params.items(): if name in input_params: @@ -2072,11 +2129,34 @@ def run_pipeline(self, return_pipeline_state=False, **kwargs): for block in self.pipeline_blocks: pipeline, state = block(pipeline, state) - if return_pipeline_state: - return state - else: - return state.outputs + return state + def run_pipeline(self, **kwargs): + state = PipelineState() + pipeline = self.pipeline + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # Add inputs to state, using defaults if not provided + for name, default in default_params.items(): + if name in input_params: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, default) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + + # Run the pipeline + with torch.no_grad(): + for block in self.pipeline_blocks: + pipeline, state = block(pipeline, state) + + return state.get_output("images") @property def default_call_parameters(self) -> Dict[str, Any]: @@ -2096,21 +2176,21 @@ def __repr__(self): output += "----------------\n" for i, block in enumerate(self.pipeline_blocks, 1): output += f"{i}. {block.__class__.__name__}\n" - + intermediates_str = "" - if hasattr(block, 'intermediates_inputs'): + if hasattr(block, "intermediates_inputs"): intermediates_str += f"{', '.join(block.intermediates_inputs)}" - - if hasattr(block, 'intermediates_outputs'): + + if hasattr(block, "intermediates_outputs"): if intermediates_str: intermediates_str += " -> " else: intermediates_str += "-> " intermediates_str += f"{', '.join(block.intermediates_outputs)}" - + if intermediates_str: output += f" {intermediates_str}\n" - + output += "\n" output += "\n" @@ -2128,5 +2208,13 @@ def __repr__(self): for name, default in params.items(): output += f"{name}: {default!r}\n" + # Add a section for required call parameters: + # intermediate inputs for the first block + output += "\nRequired Call Parameters:\n" + output += "--------------------------\n" + for name in self.pipeline_blocks[0].intermediates_inputs: + output += f"{name}: \n" + params[name] = "" + output += "\nNote: These are the default values. Actual values may be different when running the pipeline." return output From af9572d759c9840888065c424a63f806300b27d9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 19 Oct 2024 12:36:12 +0200 Subject: [PATCH 006/170] controlnet --- .../pipelines/custom_pipeline_builder.py | 99 ++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index fa4a6bb2be8e..fc082958b883 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin from ..image_processor import VaeImageProcessor from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ..models import ImageProjection +from ..models import ImageProjection, ControlNetModel from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ..models.lora import adjust_lora_scale_text_encoder from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers @@ -1160,6 +1160,103 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +from ..utils import is_compiled_module +from ..models import MultiControlNetModel, ControlNetModel +class ControlNetStep(PipelineBlock): + + def __init__(self, controlnet: ControlNetModel): + super().__init__(controlnet=controlnet) + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + + control_guide_start = state.get_input("control_guide_start") + control_guide_end = state.get_input("control_guide_end") + controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") + guess_mode = state.get_input("guess_mode") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + + batch_size = state.get_intermediate("batch_size") + timesteps = state.get_intermediate("timesteps") + + do_classifier_free_guidance = guidance_scale > 1.0 + device = pipeline._execution_device + + + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = pipeline.prepare_controlnet_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = pipeline.prepare_controlnet_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + return pipeline, state + + class TextEncoderStep(PipelineBlock): optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] From 2b6dcbfa1db5fef930de63d96d8c15598089b936 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 20 Oct 2024 19:23:37 +0200 Subject: [PATCH 007/170] fix controlnet --- .../pipelines/custom_pipeline_builder.py | 301 ++++++++++++++---- 1 file changed, 247 insertions(+), 54 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index fc082958b883..bbce5ac23b29 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -29,13 +29,14 @@ from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ..models.lora import adjust_lora_scale_text_encoder from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..utils.torch_utils import randn_tensor +from ..utils.torch_utils import randn_tensor, is_compiled_module from .pipeline_loading_utils import _fetch_class_library_tuple from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .stable_diffusion_xl import ( StableDiffusionXLPipeline, StableDiffusionXLPipelineOutput, ) +from .controlnet.multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -78,11 +79,6 @@ def __init__(self): self.register_to_config() self.builder = None - def __repr__(self): - if self.builder: - return repr(self.builder) - return "CustomPipeline (not fully initialized)" - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.register_modules def register_modules(self, **kwargs): for name, module in kwargs.items(): @@ -166,12 +162,6 @@ def components(self) -> Dict[str, Any]: for block in self.builder.pipeline_blocks: components.update(block.components) - # Check if all items in config that are also in any block's components are included - for key in self.config.keys(): - if any(key in block.components for block in self.builder.pipeline_blocks): - if key not in components: - components[key] = getattr(self, key, None) - return components # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar @@ -1068,7 +1058,7 @@ def __init__(self, **kwargs): self.configs[key] = value @classmethod - def from_pipe(cls, pipe: DiffusionPipeline): + def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): """ Create a PipelineBlock instance from a diffusion pipeline object. @@ -1078,34 +1068,49 @@ def from_pipe(cls, pipe: DiffusionPipeline): Returns: PipelineBlock: An instance initialized with the pipeline's components and configurations. """ - kwargs = {} - - # Add components + kwargs = kwargs.copy() + # add components + expected_components = set(cls.required_components + cls.optional_components) + # - components that are passed in kwargs + components_to_add = {component_name: kwargs.pop(component_name) for component_name in expected_components if component_name in kwargs} + # - components that are in the pipeline for component_name, component in pipe.components.items(): - if component_name in cls.required_components or component_name in cls.optional_components: - kwargs[component_name] = component + if component_name in expected_components and component_name not in components_to_add: + components_to_add[component_name] = component - # Add config items that are in the __init__ signature - init_params = inspect.signature(cls.__init__).parameters - for config_name in pipe.config.keys(): - if config_name in init_params and config_name not in kwargs: - kwargs[config_name] = pipe.config[config_name] - # Check for required auxiliaries + # add auxiliaries + # - auxiliaries that are passed in kwargs + auxiliaries_to_add = {k: kwargs.pop(k) for k in cls.required_auxiliaries if k in kwargs} + # - auxiliaries that are in the pipeline for aux_name in cls.required_auxiliaries: - if hasattr(pipe, aux_name): - kwargs[aux_name] = getattr(pipe, aux_name) + if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add: + auxiliaries_to_add[aux_name] = getattr(pipe, aux_name) + block_kwargs = {**components_to_add, **auxiliaries_to_add} - # Add any remaining relevant attributes + # add pipeline configs + init_params = inspect.signature(cls.__init__).parameters + # modules info are also registered in the config as tuples, e.g. {'tokenizer': ('transformers', 'CLIPTokenizer')} + # we need to exclude them for block_kwargs otherwise it will override the actual module + expected_configs = {k for k in pipe.config.keys() if k in init_params and k not in expected_components and k not in cls.required_auxiliaries} + + for config_name in expected_configs: + if config_name not in block_kwargs: + if config_name in kwargs: + # - configs that are passed in kwargs + block_kwargs[config_name] = kwargs.pop(config_name) + else: + # - configs that are in the pipeline + block_kwargs[config_name] = pipe.config[config_name] + + + # Add any remaining relevant pipeline attributes for attr_name in dir(pipe): - if ( - not attr_name.startswith("_") - and attr_name not in kwargs - and attr_name not in ["components", "config"] + if (attr_name not in block_kwargs and attr_name in init_params ): - kwargs[attr_name] = getattr(pipe, attr_name) + block_kwargs[attr_name] = getattr(pipe, attr_name) - return cls(**kwargs) + return cls(**block_kwargs) def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -1160,21 +1165,50 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -from ..utils import is_compiled_module -from ..models import MultiControlNetModel, ControlNetModel class ControlNetStep(PipelineBlock): - def __init__(self, controlnet: ControlNetModel): - super().__init__(controlnet=controlnet) + required_components = ["controlnet"] + required_auxiliaries = ["control_image_processor"] - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __init__(self, controlnet: ControlNetModel, control_image_processor=None, vae_scale_factor: float = 8.0): + if control_image_processor is None: + control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False, do_convert_rgb=True) + super().__init__(controlnet=controlnet, control_image_processor=control_image_processor, vae_scale_factor=vae_scale_factor) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("control_image", None), + ("control_guidance_start", 0.0), + ("control_guidance_end", 1.0), + ("controlnet_conditioning_scale", 1.0), + ("guess_mode", False), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("width", None), + ("height", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "timesteps"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["controlnet_keep", "controlnet_image", "controlnet_conditioning_scale", "guess_mode"] - control_guide_start = state.get_input("control_guide_start") - control_guide_end = state.get_input("control_guide_end") + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + control_image = state.get_input("control_image") + control_guidance_start = state.get_input("control_guidance_start") + control_guidance_end = state.get_input("control_guidance_end") controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") guess_mode = state.get_input("guess_mode") num_images_per_prompt = state.get_input("num_images_per_prompt") guidance_scale = state.get_input("guidance_scale") + width = state.get_input("width") + height = state.get_input("height") batch_size = state.get_intermediate("batch_size") timesteps = state.get_intermediate("timesteps") @@ -1210,8 +1244,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # 4. Prepare image if isinstance(controlnet, ControlNetModel): - image = pipeline.prepare_controlnet_image( - image=image, + control_image = pipeline.prepare_control_image( + image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, @@ -1221,13 +1255,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) - height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): - images = [] + control_images = [] - for image_ in image: - image_ = pipeline.prepare_controlnet_image( - image=image_, + for control_image_ in control_image: + control_image = pipeline.prepare_control_image( + image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, @@ -1238,10 +1271,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: guess_mode=guess_mode, ) - images.append(image_) + control_images.append(control_image) - image = images - height, width = image[0].shape[-2:] + control_image = control_images else: assert False @@ -1253,6 +1285,13 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + + state.add_intermediate("controlnet_keep", controlnet_keep) + state.add_intermediate("control_image", control_image) + state.add_intermediate("controlnet_conditioning_scale", controlnet_conditioning_scale) + state.add_intermediate("guess_mode", guess_mode) + return pipeline, state @@ -2035,6 +2074,155 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class ControlNetDenoiseStep(PipelineBlock): + required_components = ["unet", "controlnet", "scheduler"] + required_auxiliaries = ["guider"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("guidance_scale", 5.0), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + "latents", + "timesteps", + "num_inference_steps", + "add_text_embeds", + "add_time_ids", + "timestep_cond", + "prompt_embeds", + "guess_mode", + "controlnet_conditioning_scale", + "controlnet_keep", + "control_image", + ] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + + def __init__(self, unet=None, controlnet=None, scheduler=None, guider=None): + if guider is None: + guider = CFGGuider() + super().__init__(unet=unet, controlnet=controlnet, scheduler=scheduler, guider=guider) + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + guidance_scale = state.get_input("guidance_scale") + guidance_rescale = state.get_input("guidance_rescale") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + generator = state.get_input("generator") + eta = state.get_input("eta") + + latents = state.get_intermediate("latents") + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") + + add_text_embeds = state.get_intermediate("add_text_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + timestep_cond = state.get_intermediate("timestep_cond") + prompt_embeds = state.get_intermediate("prompt_embeds") + guess_mode = state.get_intermediate("guess_mode") + controlnet_conditioning_scale = state.get_intermediate("controlnet_conditioning_scale") + controlnet_keep = state.get_intermediate("controlnet_keep") + control_image = state.get_intermediate("control_image") + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = pipeline.guider.prepare_inputs_for_cfg( + latents, latents, do_classifier_free_guidance + ) + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = pipeline.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = pipeline.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + noise_pred = pipeline.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + # perform guidance + noise_pred = pipeline.guider.apply_guidance( + noise_pred, guidance_scale, do_classifier_free_guidance, guidance_rescale + ) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + state.add_intermediate("latents", latents) + + return pipeline, state + + + + class DecodeLatentsStep(PipelineBlock): optional_components = ["vae"] required_auxiliaries = ["image_processor"] @@ -2155,7 +2343,7 @@ def __init__(self, pipeline_class: str): raise ValueError(f"Pipeline class {pipeline_class} not supported") self.pipeline_blocks = [] self.pipeline.builder = self - + def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]): if not isinstance(pipeline_blocks, list): pipeline_blocks = [pipeline_blocks] @@ -2224,9 +2412,14 @@ def run_blocks(self, state: PipelineState = None, **kwargs): # Run the pipeline with torch.no_grad(): for block in self.pipeline_blocks: - pipeline, state = block(pipeline, state) - - return state + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = f"Error in block: ({block.__class__.__name__}):\n" + logger.error(error_msg) + raise + + return state def run_pipeline(self, **kwargs): state = PipelineState() From 70272b1108c6bf70409c8490af0b56e4daef780c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 20 Oct 2024 19:45:00 +0200 Subject: [PATCH 008/170] combine controlnetstep into contronetdesnoisestep --- .../pipelines/custom_pipeline_builder.py | 281 ++++++++---------- 1 file changed, 127 insertions(+), 154 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index bbce5ac23b29..4b6d7587ef57 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -25,18 +25,18 @@ from ..configuration_utils import ConfigMixin from ..image_processor import VaeImageProcessor from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ..models import ImageProjection, ControlNetModel +from ..models import ControlNetModel, ImageProjection from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ..models.lora import adjust_lora_scale_text_encoder from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..utils.torch_utils import randn_tensor, is_compiled_module +from ..utils.torch_utils import is_compiled_module, randn_tensor +from .controlnet.multicontrolnet import MultiControlNetModel from .pipeline_loading_utils import _fetch_class_library_tuple from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .stable_diffusion_xl import ( StableDiffusionXLPipeline, StableDiffusionXLPipelineOutput, ) -from .controlnet.multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1072,7 +1072,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): # add components expected_components = set(cls.required_components + cls.optional_components) # - components that are passed in kwargs - components_to_add = {component_name: kwargs.pop(component_name) for component_name in expected_components if component_name in kwargs} + components_to_add = { + component_name: kwargs.pop(component_name) + for component_name in expected_components + if component_name in kwargs + } # - components that are in the pipeline for component_name, component in pipe.components.items(): if component_name in expected_components and component_name not in components_to_add: @@ -1091,7 +1095,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): init_params = inspect.signature(cls.__init__).parameters # modules info are also registered in the config as tuples, e.g. {'tokenizer': ('transformers', 'CLIPTokenizer')} # we need to exclude them for block_kwargs otherwise it will override the actual module - expected_configs = {k for k in pipe.config.keys() if k in init_params and k not in expected_components and k not in cls.required_auxiliaries} + expected_configs = { + k + for k in pipe.config.keys() + if k in init_params and k not in expected_components and k not in cls.required_auxiliaries + } for config_name in expected_configs: if config_name not in block_kwargs: @@ -1101,13 +1109,10 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): else: # - configs that are in the pipeline block_kwargs[config_name] = pipe.config[config_name] - # Add any remaining relevant pipeline attributes for attr_name in dir(pipe): - if (attr_name not in block_kwargs - and attr_name in init_params - ): + if attr_name not in block_kwargs and attr_name in init_params: block_kwargs[attr_name] = getattr(pipe, attr_name) return cls(**block_kwargs) @@ -1165,137 +1170,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class ControlNetStep(PipelineBlock): - - required_components = ["controlnet"] - required_auxiliaries = ["control_image_processor"] - - def __init__(self, controlnet: ControlNetModel, control_image_processor=None, vae_scale_factor: float = 8.0): - if control_image_processor is None: - control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False, do_convert_rgb=True) - super().__init__(controlnet=controlnet, control_image_processor=control_image_processor, vae_scale_factor=vae_scale_factor) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("control_image", None), - ("control_guidance_start", 0.0), - ("control_guidance_end", 1.0), - ("controlnet_conditioning_scale", 1.0), - ("guess_mode", False), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("width", None), - ("height", None), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "timesteps"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["controlnet_keep", "controlnet_image", "controlnet_conditioning_scale", "guess_mode"] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - control_image = state.get_input("control_image") - control_guidance_start = state.get_input("control_guidance_start") - control_guidance_end = state.get_input("control_guidance_end") - controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") - guess_mode = state.get_input("guess_mode") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - width = state.get_input("width") - height = state.get_input("height") - - batch_size = state.get_intermediate("batch_size") - timesteps = state.get_intermediate("timesteps") - - do_classifier_free_guidance = guidance_scale > 1.0 - device = pipeline._execution_device - - - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - - - # 4. Prepare image - if isinstance(controlnet, ControlNetModel): - control_image = pipeline.prepare_control_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in control_image: - control_image = pipeline.prepare_control_image( - image=control_image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - control_images.append(control_image) - - control_image = control_images - else: - assert False - - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - - state.add_intermediate("controlnet_keep", controlnet_keep) - state.add_intermediate("control_image", control_image) - state.add_intermediate("controlnet_conditioning_scale", controlnet_conditioning_scale) - state.add_intermediate("guess_mode", guess_mode) - - - return pipeline, state - - class TextEncoderStep(PipelineBlock): optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] @@ -2076,11 +1950,17 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class ControlNetDenoiseStep(PipelineBlock): required_components = ["unet", "controlnet", "scheduler"] - required_auxiliaries = ["guider"] + required_auxiliaries = ["guider", "control_image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: return [ + ("control_image", None), + ("control_guidance_start", 0.0), + ("control_guidance_end", 1.0), + ("controlnet_conditioning_scale", 1.0), + ("guess_mode", False), + ("num_images_per_prompt", 1), ("guidance_scale", 5.0), ("guidance_rescale", 0.0), ("cross_attention_kwargs", None), @@ -2092,6 +1972,7 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ "latents", + "batch_size", "timesteps", "num_inference_steps", "add_text_embeds", @@ -2108,10 +1989,27 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, unet=None, controlnet=None, scheduler=None, guider=None): + def __init__( + self, + unet=None, + controlnet=None, + scheduler=None, + guider=None, + control_image_processor=None, + vae_scale_factor=8.0, + ): if guider is None: guider = CFGGuider() - super().__init__(unet=unet, controlnet=controlnet, scheduler=scheduler, guider=guider) + if control_image_processor is None: + control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__( + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + guider=guider, + control_image_processor=control_image_processor, + vae_scale_factor=vae_scale_factor, + ) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -2120,7 +2018,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: cross_attention_kwargs = state.get_input("cross_attention_kwargs") generator = state.get_input("generator") eta = state.get_input("eta") + control_image = state.get_input("control_image") + control_guidance_start = state.get_input("control_guidance_start") + control_guidance_end = state.get_input("control_guidance_end") + controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") + guess_mode = state.get_input("guess_mode") + num_images_per_prompt = state.get_input("num_images_per_prompt") + batch_size = state.get_intermediate("batch_size") latents = state.get_intermediate("latents") timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") @@ -2129,12 +2034,82 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: add_time_ids = state.get_intermediate("add_time_ids") timestep_cond = state.get_intermediate("timestep_cond") prompt_embeds = state.get_intermediate("prompt_embeds") - guess_mode = state.get_intermediate("guess_mode") - controlnet_conditioning_scale = state.get_intermediate("controlnet_conditioning_scale") - controlnet_keep = state.get_intermediate("controlnet_keep") - control_image = state.get_intermediate("control_image") do_classifier_free_guidance = guidance_scale > 1.0 + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + # prepare controlnet inputs + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = pipeline.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image = pipeline.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image) + + control_image = control_images + else: + assert False + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) @@ -2189,7 +2164,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - + noise_pred = pipeline.unet( latent_model_input, t, @@ -2221,8 +2196,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state - - class DecodeLatentsStep(PipelineBlock): optional_components = ["vae"] required_auxiliaries = ["image_processor"] @@ -2343,7 +2316,7 @@ def __init__(self, pipeline_class: str): raise ValueError(f"Pipeline class {pipeline_class} not supported") self.pipeline_blocks = [] self.pipeline.builder = self - + def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]): if not isinstance(pipeline_blocks, list): pipeline_blocks = [pipeline_blocks] @@ -2414,12 +2387,12 @@ def run_blocks(self, state: PipelineState = None, **kwargs): for block in self.pipeline_blocks: try: pipeline, state = block(pipeline, state) - except Exception as e: + except Exception: error_msg = f"Error in block: ({block.__class__.__name__}):\n" logger.error(error_msg) - raise + raise - return state + return state def run_pipeline(self, **kwargs): state = PipelineState() From 46ec1743a232efac634df71ed8c15396d8e5dda8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 23 Oct 2024 21:42:40 +0200 Subject: [PATCH 009/170] refactor guider, remove prepareguidance step to be combinedd into denoisestep --- .../pipelines/custom_pipeline_builder.py | 318 +++++++++++------- 1 file changed, 200 insertions(+), 118 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index 4b6d7587ef57..e8dacd483a98 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -282,21 +282,43 @@ class CFGGuider: This class is used to guide the pipeline with CFG (Classifier-Free Guidance). """ - def prepare_inputs_for_cfg( - self, negative_cond_input: torch.Tensor, cond_input: torch.Tensor, do_classifier_free_guidance: bool - ) -> torch.Tensor: - if do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input + def prepare_input( + self, + negative_cond_input: Union[torch.Tensor, List[torch.Tensor]], + cond_input: Union[torch.Tensor, List[torch.Tensor]], + do_classifier_free_guidance: bool, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Prepare the input for CFG. - def prepare_inputs(self, cfg_input_mapping: Dict[str, Any], do_classifier_free_guidance: bool) -> Dict[str, Any]: - prepared_inputs = {} - for cfg_input_name, (negative_cond_input, cond_input) in cfg_input_mapping.items(): - prepared_inputs[cfg_input_name] = self.prepare_inputs_for_cfg( - negative_cond_input, cond_input, do_classifier_free_guidance - ) - return prepared_inputs + Args: + negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. + cond_input (Union[torch.Tensor, List[torch.Tensor]]): + The conditional input. It can be a single tensor or a + list of tensors. It must have the same length as `negative_cond_input`. + do_classifier_free_guidance (bool): Whether to perform classifier-free guidance. + + Returns: + Union[torch.Tensor, List[torch.Tensor]]: The prepared input. + """ + if isinstance(negative_cond_input, list) and isinstance(cond_input, list): + if len(negative_cond_input) != len(cond_input): + raise ValueError("The length of negative_cond_input and cond_input must be the same.") + prepared_input = [] + for neg_cond, cond in zip(negative_cond_input, cond_input): + if do_classifier_free_guidance: + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) + else: + prepared_input.append(cond) + return prepared_input + elif isinstance(negative_cond_input, torch.Tensor) and isinstance(cond_input, torch.Tensor): + if do_classifier_free_guidance: + return torch.cat([negative_cond_input, cond_input], dim=0) + else: + return cond_input + else: + raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") def apply_guidance( self, @@ -442,7 +464,8 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # return image without apply any guidance def prepare_control_image( self, image, @@ -468,9 +491,6 @@ def prepare_control_image( image = image.to(device=device, dtype=dtype) - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - return image # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt @@ -1791,67 +1811,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin return pipeline, state -class PrepareGuidance(PipelineBlock): - required_auxiliaries = ["guider"] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("guidance_scale", 5.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - "add_time_ids", - "negative_add_time_ids", - "prompt_embeds", - "negative_prompt_embeds", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - ] - - @property - def intermediates_outputs(self) -> List[str]: - return ["add_text_embeds", "add_time_ids", "prompt_embeds"] - - def __init__(self, guider=None): - if guider is None: - guider = CFGGuider() - super().__init__(guider=guider) - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - - do_classifier_free_guidance = guidance_scale > 1.0 - guider = pipeline.guider - - # Fetch all model inputs from pipeline_state - conditional_inputs = { - "prompt_embeds": (negative_prompt_embeds, prompt_embeds), - "add_time_ids": (negative_add_time_ids, add_time_ids), - "add_text_embeds": (negative_pooled_prompt_embeds, pooled_prompt_embeds), - } - - # Prepare inputs using the guider - prepared_conditional_inputs = guider.prepare_inputs(conditional_inputs, do_classifier_free_guidance) - - # Add prepared inputs back to the state - state.add_intermediate("add_text_embeds", prepared_conditional_inputs["add_text_embeds"]) - state.add_intermediate("add_time_ids", prepared_conditional_inputs["add_time_ids"]) - state.add_intermediate("prompt_embeds", prepared_conditional_inputs["prompt_embeds"]) - - return pipeline, state - - class DenoiseStep(PipelineBlock): required_components = ["unet", "scheduler"] required_auxiliaries = ["guider"] @@ -1872,10 +1831,13 @@ def intermediates_inputs(self) -> List[str]: "latents", "timesteps", "num_inference_steps", - "add_text_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", "add_time_ids", + "negative_add_time_ids", "timestep_cond", "prompt_embeds", + "negative_prompt_embeds", ] @property @@ -1891,21 +1853,40 @@ def __init__(self, unet=None, scheduler=None, guider=None): def __call__(self, pipeline, state: PipelineState) -> PipelineState: guidance_scale = state.get_input("guidance_scale") guidance_rescale = state.get_input("guidance_rescale") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") generator = state.get_input("generator") eta = state.get_input("eta") + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + + timestep_cond = state.get_intermediate("timestep_cond") latents = state.get_intermediate("latents") + timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") - add_text_embeds = state.get_intermediate("add_text_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - timestep_cond = state.get_intermediate("timestep_cond") - prompt_embeds = state.get_intermediate("prompt_embeds") - do_classifier_free_guidance = guidance_scale > 1.0 + # Prepare conditional inputs using the guider + prompt_embeds = pipeline.guider.prepare_input( + negative_prompt_embeds, prompt_embeds, do_classifier_free_guidance + ) + add_time_ids = pipeline.guider.prepare_input(negative_add_time_ids, add_time_ids, do_classifier_free_guidance) + pooled_prompt_embeds = pipeline.guider.prepare_input( + negative_pooled_prompt_embeds, pooled_prompt_embeds, do_classifier_free_guidance + ) + + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) @@ -1913,12 +1894,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_inputs_for_cfg( - latents, latents, do_classifier_free_guidance - ) + latent_model_input = pipeline.guider.prepare_input(latents, latents, do_classifier_free_guidance) latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = pipeline.unet( latent_model_input, t, @@ -1975,14 +1953,13 @@ def intermediates_inputs(self) -> List[str]: "batch_size", "timesteps", "num_inference_steps", - "add_text_embeds", + "prompt_embeds", + "negative_prompt_embeds", "add_time_ids", + "negative_add_time_ids", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", "timestep_cond", - "prompt_embeds", - "guess_mode", - "controlnet_conditioning_scale", - "controlnet_keep", - "control_image", ] @property @@ -2018,22 +1995,27 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: cross_attention_kwargs = state.get_input("cross_attention_kwargs") generator = state.get_input("generator") eta = state.get_input("eta") + num_images_per_prompt = state.get_input("num_images_per_prompt") + # controlnet-specific inputs control_image = state.get_input("control_image") control_guidance_start = state.get_input("control_guidance_start") control_guidance_end = state.get_input("control_guidance_end") controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") guess_mode = state.get_input("guess_mode") - num_images_per_prompt = state.get_input("num_images_per_prompt") batch_size = state.get_intermediate("batch_size") latents = state.get_intermediate("latents") timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") - add_text_embeds = state.get_intermediate("add_text_embeds") + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + timestep_cond = state.get_intermediate("timestep_cond") - prompt_embeds = state.get_intermediate("prompt_embeds") do_classifier_free_guidance = guidance_scale > 1.0 device = pipeline._execution_device @@ -2111,34 +2093,56 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + # Prepare conditional inputs for unet using the guider + prompt_embeds = pipeline.guider.prepare_input( + negative_prompt_embeds, prompt_embeds, do_classifier_free_guidance + ) + add_time_ids = pipeline.guider.prepare_input(negative_add_time_ids, add_time_ids, do_classifier_free_guidance) + pooled_prompt_embeds = pipeline.guider.prepare_input( + negative_pooled_prompt_embeds, pooled_prompt_embeds, do_classifier_free_guidance + ) + + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } + + # Prepare conditional inputs for controlnet using the guider + # common inputs: prompt_embeds, add_time_ids, pooled_prompt_embeds + controlnet_do_classifier_free_guidance = do_classifier_free_guidance and not guess_mode + if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: + # when `guess_mode` and `do_classifier_free_guidance` are both enabled, we apply guidance for unet but not for controlnet + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + # if `guess_mode` is not enabled or `do_classifier_free_guidance` is not enabled, these inputs should be the same as unet + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + # controlnet-specific inputs: control_image + control_image = pipeline.guider.prepare_input( + control_image, control_image, controlnet_do_classifier_free_guidance + ) + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_inputs_for_cfg( - latents, latents, do_classifier_free_guidance - ) + # prepare latents for unet using the guider + latent_model_input = pipeline.guider.prepare_input(latents, latents, do_classifier_free_guidance) latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. + # prepare latents for controlnet using the guider + if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: control_model_input = latents control_model_input = pipeline.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - controlnet_added_cond_kwargs = { - "text_embeds": add_text_embeds.chunk(2)[1], - "time_ids": add_time_ids.chunk(2)[1], - } else: control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -2158,10 +2162,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, ) - if guess_mode and do_classifier_free_guidance: - # Inferred ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. + if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: + # when we apply guidance for unet, but not for controlnet: + # add 0 to the unconditional batch down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) @@ -2283,6 +2286,80 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class PAGGuider: + """ + This class is used to guide the pipeline with CFG (Classifier-Free Guidance). + """ + + def prepare_input( + self, + negative_cond_input: Union[torch.Tensor, List[torch.Tensor]], + cond_input: Union[torch.Tensor, List[torch.Tensor]], + do_classifier_free_guidance: bool, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Prepare the input for CFG. + + Args: + negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. + cond_input (Union[torch.Tensor, List[torch.Tensor]]): + The conditional input. It can be a single tensor or a + list of tensors. It must have the same length as `negative_cond_input`. + do_classifier_free_guidance (bool): Whether to perform classifier-free guidance. + + Returns: + Union[torch.Tensor, List[torch.Tensor]]: The prepared input. + """ + if isinstance(negative_cond_input, list) and isinstance(cond_input, list): + if len(negative_cond_input) != len(cond_input): + raise ValueError("The length of negative_cond_input and cond_input must be the same.") + + prepared_input = [] + for neg_cond, cond in zip(negative_cond_input, cond_input): + cond = torch.cat([cond] * 2, dim=0) + + if do_classifier_free_guidance: + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) + else: + prepared_input.append(cond) + + return prepared_input + elif isinstance(negative_cond_input, torch.Tensor) and isinstance(cond_input, torch.Tensor): + cond_input = torch.cat([cond_input] * 2, dim=0) + + if do_classifier_free_guidance: + return torch.cat([negative_cond_input, cond_input], dim=0) + else: + return cond_input + else: + raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") + + def apply_guidance( + self, + model_output: torch.Tensor, + pag_scale: float, + guidance_scale: float, + do_classifier_free_guidance: bool, + guidance_rescale: float = 0.0, + ) -> torch.Tensor: + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = model_output.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + return noise_pred + + class PipelineBlockType(Enum): InputStep = 1 TextEncoderStep = 2 @@ -2417,7 +2494,12 @@ def run_pipeline(self, **kwargs): # Run the pipeline with torch.no_grad(): for block in self.pipeline_blocks: - pipeline, state = block(pipeline, state) + try: + pipeline, state = block(pipeline, state) + except Exception: + error_msg = f"Error in block: ({block.__class__.__name__}):\n" + logger.error(error_msg) + raise return state.get_output("images") From f1b3036ca1b6800956ae3167651d7eb37df72794 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 24 Oct 2024 00:14:59 +0200 Subject: [PATCH 010/170] update pag guider - draft --- .../pipelines/custom_pipeline_builder.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index e8dacd483a98..8f4c830664a4 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -2286,11 +2286,54 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.attention_processor import PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 + class PAGGuider: """ This class is used to guide the pipeline with CFG (Classifier-Free Guidance). """ + def __init__(self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): + r""" + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") + + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) + + self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors + def prepare_input( self, negative_cond_input: Union[torch.Tensor, List[torch.Tensor]], From 540d3032504a22c07db4911a26ba512720afee2d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 26 Oct 2024 21:17:06 +0200 Subject: [PATCH 011/170] refactor guider --- .../pipelines/custom_pipeline_builder.py | 184 +++++++++++++----- 1 file changed, 134 insertions(+), 50 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index 8f4c830664a4..165f01473877 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -282,38 +282,89 @@ class CFGGuider: This class is used to guide the pipeline with CFG (Classifier-Free Guidance). """ + + def set_up_guider(self, guider_kwargs: Dict[str, Any]): + do_classifier_free_guidance = guider_kwargs.get("do_classifier_free_guidance", None) + if do_classifier_free_guidance is None: + raise ValueError("do_classifier_free_guidance is not provided in guider_kwargs") + guidance_scale = guider_kwargs.get("guidance_scale", None) + if guidance_scale is None: + raise ValueError("guidance_scale is not provided in guider_kwargs") + guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) + batch_size = guider_kwargs.get("batch_size", None) + if batch_size is None: + raise ValueError("batch_size is not provided in guider_kwargs") + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.do_classifier_free_guidance = do_classifier_free_guidance + self.batch_size = batch_size + + + + def maybe_split_prepared_input(self, cond): + """ + Process and potentially split the conditional input for Classifier-Free Guidance (CFG). + + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + It determines whether to split the input based on its batch size relative to the expected batch size. + + Args: + cond (torch.Tensor): The conditional input tensor to process. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The negative conditional input (uncond_input) + - The positive conditional input (cond_input) + """ + if cond.shape[0] == self.batch_size * 2: + neg_cond = cond[0:self.batch_size] + cond = cond[self.batch_size:] + return neg_cond, cond + elif cond.shape[0] == self.batch_size: + return cond, cond + else: + raise ValueError(f"Unsupported input shape: {cond.shape}") + def prepare_input( self, - negative_cond_input: Union[torch.Tensor, List[torch.Tensor]], cond_input: Union[torch.Tensor, List[torch.Tensor]], - do_classifier_free_guidance: bool, + negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """ Prepare the input for CFG. Args: - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. cond_input (Union[torch.Tensor, List[torch.Tensor]]): The conditional input. It can be a single tensor or a list of tensors. It must have the same length as `negative_cond_input`. - do_classifier_free_guidance (bool): Whether to perform classifier-free guidance. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. Returns: Union[torch.Tensor, List[torch.Tensor]]: The prepared input. """ + + # If negative_cond_input is None, we check if cond_input already has CFG applied, and split if it is the case. + if negative_cond_input is None: + if isinstance(cond_input, list): + negative_cond_input, cond_input = zip(*[self.maybe_split_prepared_input(cond) for cond in cond_input]) + else: + negative_cond_input, cond_input = self.maybe_split_prepared_input(cond_input) + if isinstance(negative_cond_input, list) and isinstance(cond_input, list): if len(negative_cond_input) != len(cond_input): raise ValueError("The length of negative_cond_input and cond_input must be the same.") prepared_input = [] for neg_cond, cond in zip(negative_cond_input, cond_input): - if do_classifier_free_guidance: + if neg_cond.shape[0] != cond.shape[0]: + raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") + if self.do_classifier_free_guidance: prepared_input.append(torch.cat([neg_cond, cond], dim=0)) else: prepared_input.append(cond) return prepared_input elif isinstance(negative_cond_input, torch.Tensor) and isinstance(cond_input, torch.Tensor): - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: return torch.cat([negative_cond_input, cond_input], dim=0) else: return cond_input @@ -323,18 +374,16 @@ def prepare_input( def apply_guidance( self, model_output: torch.Tensor, - guidance_scale: float, - do_classifier_free_guidance: bool, - guidance_rescale: float = 0.0, ) -> torch.Tensor: - if not do_classifier_free_guidance: + + if not self.do_classifier_free_guidance: return model_output noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if guidance_rescale > 0.0: + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) return noise_pred @@ -1823,6 +1872,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ("cross_attention_kwargs", None), ("generator", None), ("eta", 0.0), + ("guider_kwargs", None), ] @property @@ -1857,7 +1907,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: cross_attention_kwargs = state.get_input("cross_attention_kwargs") generator = state.get_input("generator") eta = state.get_input("eta") + guider_kwargs = state.get_input("guider_kwargs") + batch_size = state.get_intermediate("batch_size") prompt_embeds = state.get_intermediate("prompt_embeds") negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") @@ -1873,13 +1925,29 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: do_classifier_free_guidance = guidance_scale > 1.0 + # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "do_classifier_free_guidance": do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + + pipeline.guider.set_up_guider(guider_kwargs) # Prepare conditional inputs using the guider prompt_embeds = pipeline.guider.prepare_input( - negative_prompt_embeds, prompt_embeds, do_classifier_free_guidance + prompt_embeds, + negative_prompt_embeds, + ) + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, ) - add_time_ids = pipeline.guider.prepare_input(negative_add_time_ids, add_time_ids, do_classifier_free_guidance) pooled_prompt_embeds = pipeline.guider.prepare_input( - negative_pooled_prompt_embeds, pooled_prompt_embeds, do_classifier_free_guidance + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) added_cond_kwargs = { @@ -1894,7 +1962,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_input(latents, latents, do_classifier_free_guidance) + latent_model_input = pipeline.guider.prepare_input(latents) latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = pipeline.unet( @@ -1908,7 +1976,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: )[0] # perform guidance noise_pred = pipeline.guider.apply_guidance( - noise_pred, guidance_scale, do_classifier_free_guidance, guidance_rescale + noise_pred, ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -1928,7 +1996,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class ControlNetDenoiseStep(PipelineBlock): required_components = ["unet", "controlnet", "scheduler"] - required_auxiliaries = ["guider", "control_image_processor"] + required_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1944,6 +2012,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ("cross_attention_kwargs", None), ("generator", None), ("eta", 0.0), + ("guider_kwargs", None), ] @property @@ -1972,11 +2041,14 @@ def __init__( controlnet=None, scheduler=None, guider=None, + controlnet_guider=None, control_image_processor=None, vae_scale_factor=8.0, ): if guider is None: guider = CFGGuider() + if controlnet_guider is None: + controlnet_guider = CFGGuider() if control_image_processor is None: control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) super().__init__( @@ -1984,6 +2056,7 @@ def __init__( controlnet=controlnet, scheduler=scheduler, guider=guider, + controlnet_guider=controlnet_guider, control_image_processor=control_image_processor, vae_scale_factor=vae_scale_factor, ) @@ -1993,6 +2066,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: guidance_scale = state.get_input("guidance_scale") guidance_rescale = state.get_input("guidance_rescale") cross_attention_kwargs = state.get_input("cross_attention_kwargs") + guider_kwargs = state.get_input("guider_kwargs") generator = state.get_input("generator") eta = state.get_input("eta") num_images_per_prompt = state.get_input("num_images_per_prompt") @@ -2093,13 +2167,29 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + # Prepare conditional inputs for unet using the guider + # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "do_classifier_free_guidance": do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.guider.set_up_guider(guider_kwargs) prompt_embeds = pipeline.guider.prepare_input( - negative_prompt_embeds, prompt_embeds, do_classifier_free_guidance + prompt_embeds, + negative_prompt_embeds, + ) + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, ) - add_time_ids = pipeline.guider.prepare_input(negative_add_time_ids, add_time_ids, do_classifier_free_guidance) pooled_prompt_embeds = pipeline.guider.prepare_input( - negative_pooled_prompt_embeds, pooled_prompt_embeds, do_classifier_free_guidance + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) added_cond_kwargs = { @@ -2108,22 +2198,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Prepare conditional inputs for controlnet using the guider - # common inputs: prompt_embeds, add_time_ids, pooled_prompt_embeds controlnet_do_classifier_free_guidance = do_classifier_free_guidance and not guess_mode - if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: - # when `guess_mode` and `do_classifier_free_guidance` are both enabled, we apply guidance for unet but not for controlnet - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - controlnet_added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds.chunk(2)[1], - "time_ids": add_time_ids.chunk(2)[1], - } - else: - # if `guess_mode` is not enabled or `do_classifier_free_guidance` is not enabled, these inputs should be the same as unet - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = added_cond_kwargs + controlnet_guider_kwargs = guider_kwargs or {} + controlnet_guider_kwargs = { + **controlnet_guider_kwargs, + "do_classifier_free_guidance": controlnet_do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.controlnet_guider.set_up_guider(controlnet_guider_kwargs) + controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) + controlnet_added_cond_kwargs = { + "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), + "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), + } # controlnet-specific inputs: control_image - control_image = pipeline.guider.prepare_input( - control_image, control_image, controlnet_do_classifier_free_guidance + control_image = pipeline.controlnet_guider.prepare_input( + control_image ) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -2133,16 +2225,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # prepare latents for unet using the guider - latent_model_input = pipeline.guider.prepare_input(latents, latents, do_classifier_free_guidance) - latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual + latent_model_input = pipeline.guider.prepare_input(latents) # prepare latents for controlnet using the guider - if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: - control_model_input = latents - control_model_input = pipeline.scheduler.scale_model_input(control_model_input, t) - else: - control_model_input = latent_model_input + control_model_input = pipeline.controlnet_guider.prepare_input(latents) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -2152,7 +2238,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = pipeline.controlnet( - control_model_input, + pipeline.scheduler.scale_model_input(control_model_input, t), t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, @@ -2169,7 +2255,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) noise_pred = pipeline.unet( - latent_model_input, + pipeline.scheduler.scale_model_input(latent_model_input, t), t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, @@ -2180,9 +2266,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance( - noise_pred, guidance_scale, do_classifier_free_guidance, guidance_rescale - ) + noise_pred = pipeline.guider.apply_guidance(noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From 6742f160df4a293a409eb7d6d98c1ed30856bc12 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Oct 2024 14:59:31 +0100 Subject: [PATCH 012/170] up --- src/diffusers/guider.py | 525 ++++++++++++++++++ .../pipelines/custom_pipeline_builder.py | 398 ++++++++++--- 2 files changed, 844 insertions(+), 79 deletions(-) create mode 100644 src/diffusers/guider.py diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py new file mode 100644 index 000000000000..7a4b4e1b62c4 --- /dev/null +++ b/src/diffusers/guider.py @@ -0,0 +1,525 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from .configuration_utils import ConfigMixin +from .image_processor import VaeImageProcessor +from .loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from .models import ControlNetModel, ImageProjection +from .models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor, Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 +from .models.lora import adjust_lora_scale_text_encoder +from .utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from .utils.torch_utils import is_compiled_module, randn_tensor +from .pipelines.controlnet.multicontrolnet import MultiControlNetModel +from .pipelines.pipeline_loading_utils import _fetch_class_library_tuple +from .pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin + + +import torch.nn as nn +import re + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + + +class CFGGuider: + """ + This class is used to guide the pipeline with CFG (Classifier-Free Guidance). + """ + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 and not self._disable_guidance + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def batch_size(self): + return self._batch_size + + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead + disable_guidance = guider_kwargs.get("disable_guidance", False) + guidance_scale = guider_kwargs.get("guidance_scale", None) + if guidance_scale is None: + raise ValueError("guidance_scale is not provided in guider_kwargs") + guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) + batch_size = guider_kwargs.get("batch_size", None) + if batch_size is None: + raise ValueError("batch_size is not provided in guider_kwargs") + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._batch_size = batch_size + self._disable_guidance = disable_guidance + + def reset_guider(self, pipeline): + pass + + def maybe_update_guider(self, pipeline, timestep): + pass + + def maybe_update_input(self, pipeline, cond_input): + pass + + + def _maybe_split_prepared_input(self, cond): + """ + Process and potentially split the conditional input for Classifier-Free Guidance (CFG). + + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + It determines whether to split the input based on its batch size relative to the expected batch size. + + Args: + cond (torch.Tensor): The conditional input tensor to process. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The negative conditional input (uncond_input) + - The positive conditional input (cond_input) + """ + if cond.shape[0] == self.batch_size * 2: + neg_cond = cond[0:self.batch_size] + cond = cond[self.batch_size:] + return neg_cond, cond + elif cond.shape[0] == self.batch_size: + return cond, cond + else: + raise ValueError(f"Unsupported input shape: {cond.shape}") + + + def _is_prepared_input(self, cond): + """ + Check if the input is already prepared for Classifier-Free Guidance (CFG). + + Args: + cond (torch.Tensor): The conditional input tensor to check. + + Returns: + bool: True if the input is already prepared, False otherwise. + """ + cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond + print(f"cond_tensor.shape[0]: {cond_tensor.shape[0]}") + print(f"self.batch_size: {self.batch_size}") + + return cond_tensor.shape[0] == self.batch_size * 2 + + + def prepare_input( + self, + cond_input: Union[torch.Tensor, List[torch.Tensor]], + negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Prepare the input for CFG. + + Args: + cond_input (Union[torch.Tensor, List[torch.Tensor]]): + The conditional input. It can be a single tensor or a + list of tensors. It must have the same length as `negative_cond_input`. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. + + Returns: + Union[torch.Tensor, List[torch.Tensor]]: The prepared input. + """ + + # we check if cond_input already has CFG applied, and split if it is the case. + if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: + return cond_input + + if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: + if isinstance(cond_input, list): + negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) + else: + negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) + + if not self._is_prepared_input(cond_input) and negative_cond_input is None: + raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") + + if isinstance(cond_input, (list, tuple)): + + if not self.do_classifier_free_guidance: + return cond_input + + if len(negative_cond_input) != len(cond_input): + raise ValueError("The length of negative_cond_input and cond_input must be the same.") + prepared_input = [] + for neg_cond, cond in zip(negative_cond_input, cond_input): + if neg_cond.shape[0] != cond.shape[0]: + raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) + return prepared_input + + elif isinstance(cond_input, torch.Tensor): + + if not self.do_classifier_free_guidance: + return cond_input + else: + return torch.cat([negative_cond_input, cond_input], dim=0) + + else: + raise ValueError(f"Unsupported input type: {type(cond_input)}") + + def apply_guidance( + self, + model_output: torch.Tensor, + timesteps: int = None, + ) -> torch.Tensor: + + if not self.do_classifier_free_guidance: + return model_output + + noise_pred_uncond, noise_pred_text = model_output.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + return noise_pred + + + + + +class PAGGuider: + """ + This class is used to guide the pipeline with CFG (Classifier-Free Guidance). + """ + + def __init__(self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): + r""" + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") + + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) + + self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors + + + def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + pag_attn_processors = self._pag_attn_processors + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + def is_self_attn(module: nn.Module) -> bool: + r""" + Check if the module is self-attention module based on its name. + """ + return isinstance(module, Attention) and not module.is_cross_attention + + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] + + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + + for module in target_modules: + module.processor = pag_attn_proc + + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self._disable_guidance + + @property + def do_perturbed_attention_guidance(self): + return self._pag_scale > 0 and not self._disable_guidance + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def batch_size(self): + return self._batch_size + + @property + def pag_scale(self): + return self._pag_scale + + @property + def pag_adaptive_scale(self): + return self._pag_adaptive_scale + + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + + pag_scale = guider_kwargs.get("pag_scale", 3.0) + pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) + + batch_size = guider_kwargs.get("batch_size", None) + if batch_size is None: + raise ValueError("batch_size is a required argument for PAGGuider") + + guidance_scale = guider_kwargs.get("guidance_scale", None) + guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) + disable_guidance = guider_kwargs.get("disable_guidance", False) + + if guidance_scale is None: + raise ValueError("guidance_scale is a required argument for PAGGuider") + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + self._guidance_scale = guidance_scale + self._disable_guidance = disable_guidance + self._guidance_rescale = guidance_rescale + self._batch_size = batch_size + if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: + self.original_attn_proc = pipeline.unet.attn_processors + self._set_pag_attn_processor( + model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + + def reset_guider(self, pipeline): + if self.do_perturbed_attention_guidance: + pipeline.unet.set_attn_processor(self.original_attn_proc) + pipeline.original_attn_proc = None + + + def maybe_update_guider(self, pipeline, timestep): + pass + + def maybe_update_input(self, pipeline, cond_input): + pass + + + def _is_prepared_input(self, cond): + """ + Check if the input is already prepared for Perturbed Attention Guidance (PAG). + + Args: + cond (torch.Tensor): The conditional input tensor to check. + + Returns: + bool: True if the input is already prepared, False otherwise. + """ + cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond + + return cond_tensor.shape[0] == self.batch_size * 3 + + def _maybe_split_prepared_input(self, cond): + """ + Process and potentially split the conditional input for Classifier-Free Guidance (CFG). + + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + It determines whether to split the input based on its batch size relative to the expected batch size. + + Args: + cond (torch.Tensor): The conditional input tensor to process. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The negative conditional input (uncond_input) + - The positive conditional input (cond_input) + """ + if cond.shape[0] == self.batch_size * 3: + neg_cond = cond[0:self.batch_size] + cond = cond[self.batch_size:self.batch_size * 2] + return neg_cond, cond + elif cond.shape[0] == self.batch_size: + return cond, cond + else: + raise ValueError(f"Unsupported input shape: {cond.shape}") + + + def prepare_input( + self, + cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], + negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: + """ + Prepare the input for CFG. + + Args: + cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): + The conditional input. It can be a single tensor or a + list of tensors. It must have the same length as `negative_cond_input`. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. + + Returns: + Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. + """ + + # we check if cond_input already has CFG applied, and split if it is the case. + + if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: + return cond_input + + if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: + if isinstance(cond_input, list): + negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) + else: + negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) + + if not self._is_prepared_input(cond_input) and negative_cond_input is None: + raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") + + if isinstance(cond_input, (list, tuple)): + + if not self.do_perturbed_attention_guidance: + return cond_input + + if len(negative_cond_input) != len(cond_input): + raise ValueError("The length of negative_cond_input and cond_input must be the same.") + + prepared_input = [] + for neg_cond, cond in zip(negative_cond_input, cond_input): + if neg_cond.shape[0] != cond.shape[0]: + raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") + + cond = torch.cat([cond] * 2, dim=0) + if self.do_classifier_free_guidance: + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) + else: + prepared_input.append(cond) + + return prepared_input + + elif isinstance(cond_input, torch.Tensor): + + if not self.do_perturbed_attention_guidance: + return cond_input + + cond_input = torch.cat([cond_input] * 2, dim=0) + if self.do_classifier_free_guidance: + return torch.cat([negative_cond_input, cond_input], dim=0) + else: + return cond_input + + else: + raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") + + def apply_guidance( + self, + model_output: torch.Tensor, + timestep: int, + ) -> torch.Tensor: + + if not self.do_perturbed_attention_guidance: + return model_output + + if self.do_pag_adaptive_scaling: + pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) + else: + pag_scale = self._pag_scale + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) + noise_pred = ( + noise_pred_uncond + + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = model_output.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + + if self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + return noise_pred diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index 165f01473877..cf960c576ca4 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -282,11 +282,25 @@ class CFGGuider: This class is used to guide the pipeline with CFG (Classifier-Free Guidance). """ + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 and not self._disable_guidance - def set_up_guider(self, guider_kwargs: Dict[str, Any]): - do_classifier_free_guidance = guider_kwargs.get("do_classifier_free_guidance", None) - if do_classifier_free_guidance is None: - raise ValueError("do_classifier_free_guidance is not provided in guider_kwargs") + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def batch_size(self): + return self._batch_size + + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead + disable_guidance = guider_kwargs.get("disable_guidance", False) guidance_scale = guider_kwargs.get("guidance_scale", None) if guidance_scale is None: raise ValueError("guidance_scale is not provided in guider_kwargs") @@ -294,14 +308,22 @@ def set_up_guider(self, guider_kwargs: Dict[str, Any]): batch_size = guider_kwargs.get("batch_size", None) if batch_size is None: raise ValueError("batch_size is not provided in guider_kwargs") - self.guidance_scale = guidance_scale - self.guidance_rescale = guidance_rescale - self.do_classifier_free_guidance = do_classifier_free_guidance - self.batch_size = batch_size + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._batch_size = batch_size + self._disable_guidance = disable_guidance + + def reset_guider(self, pipeline): + pass + + def maybe_update_guider(self, pipeline, timestep): + pass + def maybe_update_input(self, pipeline, cond_input): + pass - def maybe_split_prepared_input(self, cond): + def _maybe_split_prepared_input(self, cond): """ Process and potentially split the conditional input for Classifier-Free Guidance (CFG). @@ -325,6 +347,24 @@ def maybe_split_prepared_input(self, cond): else: raise ValueError(f"Unsupported input shape: {cond.shape}") + + def _is_prepared_input(self, cond): + """ + Check if the input is already prepared for Classifier-Free Guidance (CFG). + + Args: + cond (torch.Tensor): The conditional input tensor to check. + + Returns: + bool: True if the input is already prepared, False otherwise. + """ + cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond + print(f"cond_tensor.shape[0]: {cond_tensor.shape[0]}") + print(f"self.batch_size: {self.batch_size}") + + return cond_tensor.shape[0] == self.batch_size * 2 + + def prepare_input( self, cond_input: Union[torch.Tensor, List[torch.Tensor]], @@ -344,36 +384,47 @@ def prepare_input( Union[torch.Tensor, List[torch.Tensor]]: The prepared input. """ - # If negative_cond_input is None, we check if cond_input already has CFG applied, and split if it is the case. - if negative_cond_input is None: + # we check if cond_input already has CFG applied, and split if it is the case. + if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: + return cond_input + + if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self.maybe_split_prepared_input(cond) for cond in cond_input]) + negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) else: - negative_cond_input, cond_input = self.maybe_split_prepared_input(cond_input) - - if isinstance(negative_cond_input, list) and isinstance(cond_input, list): + negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) + + if not self._is_prepared_input(cond_input) and negative_cond_input is None: + raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") + + if isinstance(cond_input, (list, tuple)): + + if not self.do_classifier_free_guidance: + return cond_input + if len(negative_cond_input) != len(cond_input): raise ValueError("The length of negative_cond_input and cond_input must be the same.") prepared_input = [] for neg_cond, cond in zip(negative_cond_input, cond_input): if neg_cond.shape[0] != cond.shape[0]: raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) return prepared_input - elif isinstance(negative_cond_input, torch.Tensor) and isinstance(cond_input, torch.Tensor): - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: + + elif isinstance(cond_input, torch.Tensor): + + if not self.do_classifier_free_guidance: return cond_input + else: + return torch.cat([negative_cond_input, cond_input], dim=0) + else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") + raise ValueError(f"Unsupported input type: {type(cond_input)}") def apply_guidance( self, model_output: torch.Tensor, + timesteps: int = None, ) -> torch.Tensor: if not self.do_classifier_free_guidance: @@ -381,6 +432,7 @@ def apply_guidance( noise_pred_uncond, noise_pred_text = model_output.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) @@ -524,8 +576,6 @@ def prepare_control_image( num_images_per_prompt, device, dtype, - do_classifier_free_guidance=False, - guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] @@ -1922,20 +1972,19 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") - - do_classifier_free_guidance = guidance_scale > 1.0 + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale guider_kwargs = guider_kwargs or {} guider_kwargs = { **guider_kwargs, - "do_classifier_free_guidance": do_classifier_free_guidance, + "disable_guidance": disable_guidance, "guidance_scale": guidance_scale, "guidance_rescale": guidance_rescale, "batch_size": batch_size, } - pipeline.guider.set_up_guider(guider_kwargs) + pipeline.guider.set_guider(pipeline, guider_kwargs) # Prepare conditional inputs using the guider prompt_embeds = pipeline.guider.prepare_input( prompt_embeds, @@ -1962,7 +2011,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_input(latents) + latent_model_input = pipeline.guider.prepare_input(latents, latents) latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = pipeline.unet( @@ -1977,6 +2026,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # perform guidance noise_pred = pipeline.guider.apply_guidance( noise_pred, + timestep = t, ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -1989,6 +2039,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() + pipeline.guider.reset_guider(pipeline) state.add_intermediate("latents", latents) return pipeline, state @@ -2091,7 +2142,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timestep_cond = state.get_intermediate("timestep_cond") - do_classifier_free_guidance = guidance_scale > 1.0 device = pipeline._execution_device height, width = latents.shape[-2:] @@ -2133,8 +2183,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] @@ -2148,8 +2196,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, ) control_images.append(control_image) @@ -2169,16 +2215,17 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Prepare conditional inputs for unet using the guider - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale + # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False guider_kwargs = guider_kwargs or {} guider_kwargs = { **guider_kwargs, - "do_classifier_free_guidance": do_classifier_free_guidance, + "disable_guidance": disable_guidance, "guidance_scale": guidance_scale, "guidance_rescale": guidance_rescale, "batch_size": batch_size, } - pipeline.guider.set_up_guider(guider_kwargs) + pipeline.guider.set_guider(pipeline, guider_kwargs) prompt_embeds = pipeline.guider.prepare_input( prompt_embeds, negative_prompt_embeds, @@ -2198,16 +2245,16 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Prepare conditional inputs for controlnet using the guider - controlnet_do_classifier_free_guidance = do_classifier_free_guidance and not guess_mode + controlnet_disable_guidance = True if disable_guidance or guess_mode else False controlnet_guider_kwargs = guider_kwargs or {} - controlnet_guider_kwargs = { + controlnet_guider_kwargs = { **controlnet_guider_kwargs, - "do_classifier_free_guidance": controlnet_do_classifier_free_guidance, + "disable_guidance": controlnet_disable_guidance, "guidance_scale": guidance_scale, "guidance_rescale": guidance_rescale, "batch_size": batch_size, } - pipeline.controlnet_guider.set_up_guider(controlnet_guider_kwargs) + pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) controlnet_added_cond_kwargs = { "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), @@ -2215,7 +2262,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # controlnet-specific inputs: control_image control_image = pipeline.controlnet_guider.prepare_input( - control_image + control_image, control_image ) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -2225,10 +2272,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # prepare latents for unet using the guider - latent_model_input = pipeline.guider.prepare_input(latents) + latent_model_input = pipeline.guider.prepare_input(latents, latents) # prepare latents for controlnet using the guider - control_model_input = pipeline.controlnet_guider.prepare_input(latents) + control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -2248,11 +2295,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, ) - if do_classifier_free_guidance and not controlnet_do_classifier_free_guidance: - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + # when we apply guidance for unet, but not for controlnet: + # add 0 to the unconditional batch + down_block_res_samples = pipeline.guider.prepare_input(down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]) + mid_block_res_sample = pipeline.guider.prepare_input(mid_block_res_sample, torch.zeros_like(mid_block_res_sample)) noise_pred = pipeline.unet( pipeline.scheduler.scale_model_input(latent_model_input, t), @@ -2266,7 +2312,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance(noise_pred) + noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] @@ -2278,6 +2324,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() + pipeline.guider.reset_guider(pipeline) + pipeline.controlnet_guider.reset_guider(pipeline) state.add_intermediate("latents", latents) return pipeline, state @@ -2370,8 +2418,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.attention_processor import PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 +from diffusers.models.attention_processor import Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 +import torch.nn as nn +import re + class PAGGuider: """ @@ -2417,73 +2467,263 @@ def __init__(self, self.pag_applied_layers = pag_applied_layers self._pag_attn_processors = pag_attn_processors + + + def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + pag_attn_processors = self._pag_attn_processors + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + def is_self_attn(module: nn.Module) -> bool: + r""" + Check if the module is self-attention module based on its name. + """ + return isinstance(module, Attention) and not module.is_cross_attention + + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] + + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + + for module in target_modules: + module.processor = pag_attn_proc + + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self._disable_guidance + + @property + def do_perturbed_attention_guidance(self): + return self._pag_scale > 0 and not self._disable_guidance + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def batch_size(self): + return self._batch_size + + @property + def pag_scale(self): + return self._pag_scale + + @property + def pag_adaptive_scale(self): + return self._pag_adaptive_scale + + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + + pag_scale = guider_kwargs.get("pag_scale", 3.0) + pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) + + batch_size = guider_kwargs.get("batch_size", None) + if batch_size is None: + raise ValueError("batch_size is a required argument for PAGGuider") + + guidance_scale = guider_kwargs.get("guidance_scale", None) + guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) + disable_guidance = guider_kwargs.get("disable_guidance", False) + + if guidance_scale is None: + raise ValueError("guidance_scale is a required argument for PAGGuider") + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + self._guidance_scale = guidance_scale + self._disable_guidance = disable_guidance + self._guidance_rescale = guidance_rescale + self._batch_size = batch_size + if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: + self.original_attn_proc = pipeline.unet.attn_processors + self._set_pag_attn_processor( + model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + + def reset_guider(self, pipeline): + if self.do_perturbed_attention_guidance: + pipeline.unet.set_attn_processor(self.original_attn_proc) + pipeline.original_attn_proc = None + + + def maybe_update_guider(self, pipeline, timestep): + pass + + def maybe_update_input(self, pipeline, cond_input): + pass + + + def _is_prepared_input(self, cond): + """ + Check if the input is already prepared for Perturbed Attention Guidance (PAG). + + Args: + cond (torch.Tensor): The conditional input tensor to check. + + Returns: + bool: True if the input is already prepared, False otherwise. + """ + cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond + + return cond_tensor.shape[0] == self.batch_size * 3 + + def _maybe_split_prepared_input(self, cond): + """ + Process and potentially split the conditional input for Classifier-Free Guidance (CFG). + + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + It determines whether to split the input based on its batch size relative to the expected batch size. + + Args: + cond (torch.Tensor): The conditional input tensor to process. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The negative conditional input (uncond_input) + - The positive conditional input (cond_input) + """ + if cond.shape[0] == self.batch_size * 3: + neg_cond = cond[0:self.batch_size] + cond = cond[self.batch_size:self.batch_size * 2] + return neg_cond, cond + elif cond.shape[0] == self.batch_size: + return cond, cond + else: + raise ValueError(f"Unsupported input shape: {cond.shape}") + def prepare_input( self, - negative_cond_input: Union[torch.Tensor, List[torch.Tensor]], - cond_input: Union[torch.Tensor, List[torch.Tensor]], - do_classifier_free_guidance: bool, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], + negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: """ Prepare the input for CFG. Args: - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - cond_input (Union[torch.Tensor, List[torch.Tensor]]): + cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The conditional input. It can be a single tensor or a list of tensors. It must have the same length as `negative_cond_input`. - do_classifier_free_guidance (bool): Whether to perform classifier-free guidance. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. + Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. """ - if isinstance(negative_cond_input, list) and isinstance(cond_input, list): + + # we check if cond_input already has CFG applied, and split if it is the case. + + if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: + return cond_input + + if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: + if isinstance(cond_input, list): + negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) + else: + negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) + + if not self._is_prepared_input(cond_input) and negative_cond_input is None: + raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") + + if isinstance(cond_input, (list, tuple)): + + if not self.do_perturbed_attention_guidance: + return cond_input + if len(negative_cond_input) != len(cond_input): raise ValueError("The length of negative_cond_input and cond_input must be the same.") - + prepared_input = [] for neg_cond, cond in zip(negative_cond_input, cond_input): + if neg_cond.shape[0] != cond.shape[0]: + raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") + cond = torch.cat([cond] * 2, dim=0) - - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prepared_input.append(torch.cat([neg_cond, cond], dim=0)) else: prepared_input.append(cond) - + return prepared_input - elif isinstance(negative_cond_input, torch.Tensor) and isinstance(cond_input, torch.Tensor): - cond_input = torch.cat([cond_input] * 2, dim=0) - if do_classifier_free_guidance: + elif isinstance(cond_input, torch.Tensor): + + if not self.do_perturbed_attention_guidance: + return cond_input + + cond_input = torch.cat([cond_input] * 2, dim=0) + if self.do_classifier_free_guidance: return torch.cat([negative_cond_input, cond_input], dim=0) else: return cond_input + else: raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") def apply_guidance( self, model_output: torch.Tensor, - pag_scale: float, - guidance_scale: float, - do_classifier_free_guidance: bool, - guidance_rescale: float = 0.0, + timestep: int, ) -> torch.Tensor: - if do_classifier_free_guidance: + + if not self.do_perturbed_attention_guidance: + return model_output + + if self.do_pag_adaptive_scaling: + pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) + else: + pag_scale = self._pag_scale + + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) noise_pred = ( noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) + + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + pag_scale * (noise_pred_text - noise_pred_perturb) ) else: noise_pred_text, noise_pred_perturb = model_output.chunk(2) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - if guidance_rescale > 0.0: + if self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + return noise_pred From 005195c23e81a217e63c4913a9e196517807a194 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Oct 2024 15:18:10 +0100 Subject: [PATCH 013/170] add --- src/diffusers/guider.py | 168 +++--- .../pipelines/custom_pipeline_builder.py | 516 +----------------- 2 files changed, 87 insertions(+), 597 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index 7a4b4e1b62c4..5369bd25f9e2 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -12,34 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from dataclasses import dataclass, field -from enum import Enum +import re from typing import Any, Dict, List, Optional, Tuple, Union -import PIL import torch -from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -from .configuration_utils import ConfigMixin -from .image_processor import VaeImageProcessor -from .loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from .models import ControlNetModel, ImageProjection -from .models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor, Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 -from .models.lora import adjust_lora_scale_text_encoder -from .utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from .utils.torch_utils import is_compiled_module, randn_tensor -from .pipelines.controlnet.multicontrolnet import MultiControlNetModel -from .pipelines.pipeline_loading_utils import _fetch_class_library_tuple -from .pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +import torch.nn as nn +from .models.attention_processor import ( + Attention, + AttentionProcessor, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, +) +from .utils import logging -import torch.nn as nn -import re logger = logging.get_logger(__name__) # pylint: disable=invalid-name + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -55,7 +45,6 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg - class CFGGuider: """ This class is used to guide the pipeline with CFG (Classifier-Free Guidance). @@ -64,7 +53,7 @@ class CFGGuider: @property def do_classifier_free_guidance(self): return self._guidance_scale > 1.0 and not self._disable_guidance - + @property def guidance_rescale(self): return self._guidance_rescale @@ -76,7 +65,7 @@ def guidance_scale(self): @property def batch_size(self): return self._batch_size - + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead disable_guidance = guider_kwargs.get("disable_guidance", False) @@ -95,18 +84,17 @@ def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): def reset_guider(self, pipeline): pass - def maybe_update_guider(self, pipeline, timestep): + def maybe_update_guider(self, pipeline, timestep): pass def maybe_update_input(self, pipeline, cond_input): pass - - + def _maybe_split_prepared_input(self, cond): """ Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). It determines whether to split the input based on its batch size relative to the expected batch size. Args: @@ -118,15 +106,14 @@ def _maybe_split_prepared_input(self, cond): - The positive conditional input (cond_input) """ if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0:self.batch_size] - cond = cond[self.batch_size:] + neg_cond = cond[0 : self.batch_size] + cond = cond[self.batch_size :] return neg_cond, cond elif cond.shape[0] == self.batch_size: return cond, cond else: raise ValueError(f"Unsupported input shape: {cond.shape}") - - + def _is_prepared_input(self, cond): """ Check if the input is already prepared for Classifier-Free Guidance (CFG). @@ -138,12 +125,9 @@ def _is_prepared_input(self, cond): bool: True if the input is already prepared, False otherwise. """ cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - print(f"cond_tensor.shape[0]: {cond_tensor.shape[0]}") - print(f"self.batch_size: {self.batch_size}") return cond_tensor.shape[0] == self.batch_size * 2 - def prepare_input( self, cond_input: Union[torch.Tensor, List[torch.Tensor]], @@ -157,7 +141,7 @@ def prepare_input( The conditional input. It can be a single tensor or a list of tensors. It must have the same length as `negative_cond_input`. negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. + single tensor or a list of tensors. It must have the same length as `cond_input`. Returns: Union[torch.Tensor, List[torch.Tensor]]: The prepared input. @@ -166,21 +150,22 @@ def prepare_input( # we check if cond_input already has CFG applied, and split if it is the case. if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: return cond_input - + if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: if isinstance(cond_input, list): negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) else: negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - + if not self._is_prepared_input(cond_input) and negative_cond_input is None: - raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") + raise ValueError( + "`negative_cond_input` is required when cond_input does not already contains negative conditional input" + ) if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: return cond_input - + if len(negative_cond_input) != len(cond_input): raise ValueError("The length of negative_cond_input and cond_input must be the same.") prepared_input = [] @@ -189,9 +174,8 @@ def prepare_input( raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") prepared_input.append(torch.cat([neg_cond, cond], dim=0)) return prepared_input - + elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: return cond_input else: @@ -203,35 +187,32 @@ def prepare_input( def apply_guidance( self, model_output: torch.Tensor, - timesteps: int = None, + timestep: int = None, ) -> torch.Tensor: - if not self.do_classifier_free_guidance: return model_output noise_pred_uncond, noise_pred_text = model_output.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + if self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) return noise_pred - - - class PAGGuider: """ This class is used to guide the pipeline with CFG (Classifier-Free Guidance). """ - def __init__(self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), + def __init__( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), ): r""" Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. @@ -265,7 +246,6 @@ def __init__(self, self.pag_applied_layers = pag_applied_layers self._pag_attn_processors = pag_attn_processors - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): r""" @@ -308,42 +288,40 @@ def is_fake_integral_match(layer_id, name): for module in target_modules: module.processor = pag_attn_proc - @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and not self._disable_guidance - + @property def do_perturbed_attention_guidance(self): return self._pag_scale > 0 and not self._disable_guidance - + @property def do_pag_adaptive_scaling(self): return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - + @property def guidance_scale(self): return self._guidance_scale - + @property def guidance_rescale(self): return self._guidance_rescale - + @property def batch_size(self): return self._batch_size - + @property def pag_scale(self): return self._pag_scale - + @property def pag_adaptive_scale(self): return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): pag_scale = guider_kwargs.get("pag_scale", 3.0) pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) @@ -367,24 +345,21 @@ def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: self.original_attn_proc = pipeline.unet.attn_processors self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - + model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) def reset_guider(self, pipeline): if self.do_perturbed_attention_guidance: pipeline.unet.set_attn_processor(self.original_attn_proc) pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): + def maybe_update_guider(self, pipeline, timestep): pass def maybe_update_input(self, pipeline, cond_input): pass - def _is_prepared_input(self, cond): """ @@ -399,12 +374,12 @@ def _is_prepared_input(self, cond): cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond return cond_tensor.shape[0] == self.batch_size * 3 - + def _maybe_split_prepared_input(self, cond): """ Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). It determines whether to split the input based on its batch size relative to the expected batch size. Args: @@ -416,15 +391,14 @@ def _maybe_split_prepared_input(self, cond): - The positive conditional input (cond_input) """ if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0:self.batch_size] - cond = cond[self.batch_size:self.batch_size * 2] + neg_cond = cond[0 : self.batch_size] + cond = cond[self.batch_size : self.batch_size * 2] return neg_cond, cond elif cond.shape[0] == self.batch_size: return cond, cond else: raise ValueError(f"Unsupported input shape: {cond.shape}") - - + def prepare_input( self, cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -437,59 +411,60 @@ def prepare_input( cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The conditional input. It can be a single tensor or a list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): + The negative conditional input. It can be a single tensor or a list of tensors. It must have the same + length as `cond_input`. Returns: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. """ # we check if cond_input already has CFG applied, and split if it is the case. - + if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: return cond_input - + if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: if isinstance(cond_input, list): negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) else: negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - + if not self._is_prepared_input(cond_input) and negative_cond_input is None: - raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") - + raise ValueError( + "`negative_cond_input` is required when cond_input does not already contains negative conditional input" + ) + if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: return cond_input - + if len(negative_cond_input) != len(cond_input): raise ValueError("The length of negative_cond_input and cond_input must be the same.") - + prepared_input = [] for neg_cond, cond in zip(negative_cond_input, cond_input): if neg_cond.shape[0] != cond.shape[0]: raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - + cond = torch.cat([cond] * 2, dim=0) if self.do_classifier_free_guidance: prepared_input.append(torch.cat([neg_cond, cond], dim=0)) else: prepared_input.append(cond) - + return prepared_input elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: return cond_input - + cond_input = torch.cat([cond_input] * 2, dim=0) if self.do_classifier_free_guidance: return torch.cat([negative_cond_input, cond_input], dim=0) else: return cond_input - + else: raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") @@ -498,10 +473,9 @@ def apply_guidance( model_output: torch.Tensor, timestep: int, ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: return model_output - + if self.do_pag_adaptive_scaling: pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) else: @@ -521,5 +495,5 @@ def apply_guidance( if self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - + return noise_pred diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index cf960c576ca4..35e3e3da3eae 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ..configuration_utils import ConfigMixin +from ..guider import CFGGuider from ..image_processor import VaeImageProcessor from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ..models import ControlNetModel, ImageProjection @@ -262,183 +263,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0:self.batch_size] - cond = cond[self.batch_size:] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - print(f"cond_tensor.shape[0]: {cond_tensor.shape[0]}") - print(f"self.batch_size: {self.batch_size}") - - return cond_tensor.shape[0] == self.batch_size * 2 - - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and negative_cond_input is None: - raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") - - if isinstance(cond_input, (list, tuple)): - - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timesteps: int = None, - ) -> torch.Tensor: - - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - class SDXLCustomPipeline( CustomPipeline, StableDiffusionMixin, @@ -1988,15 +1812,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Prepare conditional inputs using the guider prompt_embeds = pipeline.guider.prepare_input( prompt_embeds, - negative_prompt_embeds, + negative_prompt_embeds, ) add_time_ids = pipeline.guider.prepare_input( add_time_ids, - negative_add_time_ids, + negative_add_time_ids, ) pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) added_cond_kwargs = { @@ -2025,8 +1849,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: )[0] # perform guidance noise_pred = pipeline.guider.apply_guidance( - noise_pred, - timestep = t, + noise_pred, + timestep=t, ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -2213,7 +2037,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - # Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False @@ -2232,7 +2055,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) add_time_ids = pipeline.guider.prepare_input( add_time_ids, - negative_add_time_ids, + negative_add_time_ids, ) pooled_prompt_embeds = pipeline.guider.prepare_input( pooled_prompt_embeds, @@ -2247,7 +2070,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Prepare conditional inputs for controlnet using the guider controlnet_disable_guidance = True if disable_guidance or guess_mode else False controlnet_guider_kwargs = guider_kwargs or {} - controlnet_guider_kwargs = { + controlnet_guider_kwargs = { **controlnet_guider_kwargs, "disable_guidance": controlnet_disable_guidance, "guidance_scale": guidance_scale, @@ -2261,9 +2084,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), } # controlnet-specific inputs: control_image - control_image = pipeline.controlnet_guider.prepare_input( - control_image, control_image - ) + control_image = pipeline.controlnet_guider.prepare_input(control_image, control_image) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) @@ -2297,8 +2118,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # when we apply guidance for unet, but not for controlnet: # add 0 to the unconditional batch - down_block_res_samples = pipeline.guider.prepare_input(down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]) - mid_block_res_sample = pipeline.guider.prepare_input(mid_block_res_sample, torch.zeros_like(mid_block_res_sample)) + down_block_res_samples = pipeline.guider.prepare_input( + down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] + ) + mid_block_res_sample = pipeline.guider.prepare_input( + mid_block_res_sample, torch.zeros_like(mid_block_res_sample) + ) noise_pred = pipeline.unet( pipeline.scheduler.scale_model_input(latent_model_input, t), @@ -2418,315 +2243,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -from diffusers.models.attention_processor import Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0 -import torch.nn as nn -import re - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__(self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. - """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - - pag_scale = guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - self.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - - def reset_guider(self, pipeline): - if self.do_perturbed_attention_guidance: - pipeline.unet.set_attn_processor(self.original_attn_proc) - pipeline.original_attn_proc = None - - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0:self.batch_size] - cond = cond[self.batch_size:self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and negative_cond_input is None: - raise ValueError("`negative_cond_input` is required when cond_input does not already contains negative conditional input") - - if isinstance(cond_input, (list, tuple)): - - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - ) -> torch.Tensor: - - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - class PipelineBlockType(Enum): InputStep = 1 TextEncoderStep = 2 From 024a9f5de36e85d59882946126c9bf84c8bba4cf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Oct 2024 18:52:56 +0100 Subject: [PATCH 014/170] fix so that run_blocks can work with inputs in the state --- src/diffusers/pipelines/custom_pipeline_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index 35e3e3da3eae..f07e61565761 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -2331,17 +2331,17 @@ def run_blocks(self, state: PipelineState = None, **kwargs): if name in input_params: state.add_intermediate(name, input_params.pop(name)) - # Add inputs to state, using defaults if not provided + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs for name, default in default_params.items(): if name in input_params: state.add_input(name, input_params.pop(name)) - else: + elif name not in state.inputs: state.add_input(name, default) # Warn about unexpected inputs if len(input_params) > 0: logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline with torch.no_grad(): for block in self.pipeline_blocks: From 37e8dc7a5933a215b69413cc9fffb7f16fdd8a12 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Oct 2024 00:37:48 +0100 Subject: [PATCH 015/170] remove img2img blocksgit status consolidate text2img and img2img --- .../pipelines/custom_pipeline_builder.py | 410 ++++++------------ 1 file changed, 143 insertions(+), 267 deletions(-) diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/custom_pipeline_builder.py index f07e61565761..d97cc32a4643 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/custom_pipeline_builder.py @@ -1291,59 +1291,17 @@ def inputs(self) -> List[Tuple[str, Any]]: ("timesteps", None), ("sigmas", None), ("denoising_end", None), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return ["timesteps", "num_inference_steps"] - - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - num_inference_steps = state.get_input("num_inference_steps") - timesteps = state.get_input("timesteps") - sigmas = state.get_input("sigmas") - denoising_end = state.get_input("denoising_end") - - device = pipeline._execution_device - - timesteps, num_inference_steps = retrieve_timesteps( - pipeline.scheduler, num_inference_steps, device, timesteps, sigmas - ) - - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - state.add_intermediate("timesteps", timesteps) - state.add_intermediate("num_inference_steps", num_inference_steps) - - return pipeline, state - - -class Image2ImageSetTimestepsStep(PipelineBlock): - required_components = ["scheduler"] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), + ("image", None), ("strength", 0.3), ("denoising_start", None), ("num_images_per_prompt", 1), + ("device", None), ] + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size"] + @property def intermediates_outputs(self) -> List[str]: return ["timesteps", "num_inference_steps", "latent_timestep"] @@ -1357,28 +1315,39 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timesteps = state.get_input("timesteps") sigmas = state.get_input("sigmas") denoising_end = state.get_input("denoising_end") + device = state.get_input("device") + + # image to image only + image = state.get_input("image") # just to check if it is an image to image workflow strength = state.get_input("strength") denoising_start = state.get_input("denoising_start") num_images_per_prompt = state.get_input("num_images_per_prompt") + # image to image only batch_size = state.get_intermediate("batch_size") - device = pipeline._execution_device - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 + if device is None: + device = pipeline._execution_device timesteps, num_inference_steps = retrieve_timesteps( pipeline.scheduler, num_inference_steps, device, timesteps, sigmas ) - timesteps, num_inference_steps = pipeline.get_timesteps( - num_inference_steps, - strength, - device, - denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + if image is not None: + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = pipeline.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + else: + latent_timestep = None if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( @@ -1397,74 +1366,9 @@ def denoising_value_valid(dnv): return pipeline, state -class Image2ImagePrepareLatentsStep(PipelineBlock): - required_components = ["vae"] - required_auxiliaries = ["image_processor"] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("image", None), - ("num_images_per_prompt", 1), - ("generator", None), - ("latents", None), - ("device", None), - ("dtype", None), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "timesteps", "num_inference_steps"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["latents", "timesteps", "num_inference_steps"] - - def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): - if image_processor is None: - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - image = state.get_input("image") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - latents = state.get_input("latents") - denoising_start = state.get_input("denoising_start") - device = state.get_input("device") - dtype = state.get_input("dtype") - - # get intermediates - batch_size = state.get_intermediate("batch_size") - latent_timestep = state.get_intermediate("latent_timestep") - - device = pipeline._execution_device if device is None else device - dtype = pipeline.vae.dtype if dtype is None else dtype - - image = pipeline.image_processor.preprocess(image) - - add_noise = True if denoising_start is None else False - - if latents is None: - latents = pipeline.prepare_latents_img2img( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator, - add_noise, - ) - - state.add_intermediate("latents", latents) - - return pipeline, state - - class PrepareLatentsStep(PipelineBlock): - required_components = ["scheduler"] + optional_components = ["vae", "scheduler"] + required_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1476,22 +1380,36 @@ def inputs(self) -> List[Tuple[str, Any]]: ("num_images_per_prompt", 1), ("device", None), ("dtype", None), + ("image", None), + ("denoising_start", None), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size"] + return ["batch_size", "latent_timestep"] @property def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) + def __init__(self, vae=None, image_processor=None, vae_scale_factor=8, scheduler=None): + if image_processor is None: + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__( + vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor, scheduler=scheduler + ) @staticmethod - def check_inputs(pipeline, height, width): - if height % pipeline.vae_scale_factor != 0 or width % pipeline.vae_scale_factor != 0: + def check_inputs(pipeline, height, width, image): + if image is not None and (height is not None or width is not None): + raise ValueError("Cannot specify both `image` and `height` or `width`") + + if ( + height is not None + and height % pipeline.vae_scale_factor != 0 + or width is not None + and width % pipeline.vae_scale_factor != 0 + ): raise ValueError( f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {height} and {width}." ) @@ -1500,14 +1418,22 @@ def check_inputs(pipeline, height, width): def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: latents = state.get_input("latents") num_images_per_prompt = state.get_input("num_images_per_prompt") - height = state.get_input("height") - width = state.get_input("width") generator = state.get_input("generator") device = state.get_input("device") dtype = state.get_input("dtype") + # text to image only + height = state.get_input("height") + width = state.get_input("width") + + # image to image only + image = state.get_input("image") + denoising_start = state.get_input("denoising_start") + batch_size = state.get_intermediate("batch_size") prompt_embeds = state.get_intermediate("prompt_embeds", None) + # image to image only + latent_timestep = state.get_intermediate("latent_timestep", None) if dtype is None and prompt_embeds is not None: dtype = prompt_embeds.dtype @@ -1517,24 +1443,36 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin if device is None: device = pipeline._execution_device - height = height or pipeline.default_sample_size * pipeline.vae_scale_factor - width = width or pipeline.default_sample_size * pipeline.vae_scale_factor - - self.check_inputs(pipeline, height, width) - - # 5. Prepare latent variables - - num_channels_latents = pipeline.num_channels_latents - latents = pipeline.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents, - ) + self.check_inputs(pipeline, height, width, image) + + if image is None: + height = height or pipeline.default_sample_size * pipeline.vae_scale_factor + width = width or pipeline.default_sample_size * pipeline.vae_scale_factor + num_channels_latents = pipeline.num_channels_latents + latents = pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents, + ) + else: + image = pipeline.image_processor.preprocess(image) + add_noise = True if denoising_start is None else False + if latents is None: + latents = pipeline.prepare_latents_img2img( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + add_noise, + ) state.add_intermediate("latents", latents) @@ -1555,108 +1493,15 @@ def inputs(self) -> List[Tuple[str, Any]]: ("negative_crops_coords_top_left", (0, 0)), ("num_images_per_prompt", 1), ("guidance_scale", 5.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["latents", "batch_size", "pooled_prompt_embeds"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - - def __init__(self, unet=None): - super().__init__(unet=unet) - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - original_size = state.get_input("original_size") - target_size = state.get_input("target_size") - negative_original_size = state.get_input("negative_original_size") - negative_target_size = state.get_input("negative_target_size") - crops_coords_top_left = state.get_input("crops_coords_top_left") - negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - - latents = state.get_intermediate("latents") - batch_size = state.get_intermediate("batch_size") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - - device = pipeline._execution_device - - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: - text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim - else: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - - add_time_ids = pipeline._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = pipeline._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - - # Optionally get Guidance Scale Embedding for LCM - timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = pipeline.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - state.add_intermediate("add_time_ids", add_time_ids) - state.add_intermediate("negative_add_time_ids", negative_add_time_ids) - state.add_intermediate("timestep_cond", timestep_cond) - return pipeline, state - - -class Image2ImagePrepareAdditionalConditioningStep(PipelineBlock): - required_components = ["unet"] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("original_sizife", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), ("aesthetic_score", 6.0), ("negative_aesthetic_score", 2.0), + ("device", None), + ("image", None), ] @property def intermediates_inputs(self) -> List[str]: - return ["latents"] + return ["latents", "batch_size", "pooled_prompt_embeds"] @property def intermediates_outputs(self) -> List[str]: @@ -1675,6 +1520,10 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") num_images_per_prompt = state.get_input("num_images_per_prompt") guidance_scale = state.get_input("guidance_scale") + device = state.get_input("device") + + # image to image only + image = state.get_input("image") aesthetic_score = state.get_input("aesthetic_score") negative_aesthetic_score = state.get_input("negative_aesthetic_score") @@ -1682,7 +1531,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin batch_size = state.get_intermediate("batch_size") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - device = pipeline._execution_device + if device is None: + device = pipeline._execution_device height, width = latents.shape[-2:] height = height * pipeline.vae_scale_factor @@ -1691,30 +1541,56 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin original_size = original_size or (height, width) target_size = target_size or (height, width) - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size - if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim else: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + if image is None: + add_time_ids = pipeline._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = pipeline._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( + device=device + ) + else: + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( + device=device + ) # Optionally get Guidance Scale Embedding for LCM timestep_cond = None From 8b811feece161bb21f9603b2ad99ea0b77f0b601 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Oct 2024 10:13:03 +0100 Subject: [PATCH 016/170] refactor, from_pretrained, from_pipe, remove_blocks, replace_blocks --- src/diffusers/__init__.py | 4 + src/diffusers/pipelines/__init__.py | 4 + src/diffusers/pipelines/auto_pipeline.py | 13 +- .../pipelines/modular_pipeline_builder.py | 939 +++++ .../pipelines/pipeline_loading_utils.py | 2 +- .../pipelines/stable_diffusion_xl/__init__.py | 22 + .../pipeline_stable_diffusion_xl_modular.py} | 3372 +++++++---------- 7 files changed, 2411 insertions(+), 1945 deletions(-) create mode 100644 src/diffusers/pipelines/modular_pipeline_builder.py rename src/diffusers/pipelines/{custom_pipeline_builder.py => stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py} (78%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 019d744730ab..e2285f548c2f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -145,6 +145,7 @@ "DDIMPipeline", "DDPMPipeline", "DiffusionPipeline", + "ModularPipelineBuilder", "DiTPipeline", "ImagePipelineOutput", "KarrasVePipeline", @@ -369,6 +370,7 @@ "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", + "StableDiffusionXLModularPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -626,6 +628,7 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, + ModularPipelineBuilder, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, @@ -819,6 +822,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLModularPipeline, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d7ff34310beb..807829e23728 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -46,6 +46,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] + _import_structure["modular_pipeline_builder"] = ["ModularPipelineBuilder"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -296,6 +297,7 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", + "StableDiffusionXLModularPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -432,6 +434,7 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline + from .modular_pipeline_builder import ModularPipelineBuilder from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -620,6 +623,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLModularPipeline, StableDiffusionXLPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 0214d7dd6f3c..ac52024c7412 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -214,14 +214,15 @@ def _get_connected_pipeline(pipeline_cls): return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False) +def _get_model(pipeline_class_name): + for task_mapping in SUPPORTED_TASKS_MAPPINGS: + for model_name, pipeline in task_mapping.items(): + if pipeline.__name__ == pipeline_class_name: + return model_name + def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): - def get_model(pipeline_class_name): - for task_mapping in SUPPORTED_TASKS_MAPPINGS: - for model_name, pipeline in task_mapping.items(): - if pipeline.__name__ == pipeline_class_name: - return model_name - model_name = get_model(pipeline_class_name) + model_name = _get_model(pipeline_class_name) if model_name is not None: task_class = mapping.get(model_name, None) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py new file mode 100644 index 000000000000..98c2a9139e44 --- /dev/null +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -0,0 +1,939 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union +import importlib +from collections import OrderedDict +import PIL +import torch +from tqdm.auto import tqdm + +from ..configuration_utils import ConfigMixin +from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ..models import ImageProjection +from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ..models.lora import adjust_lora_scale_text_encoder +from ..utils import ( + USE_PEFT_BACKEND, + is_accelerate_available, + is_accelerate_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ..utils.hub_utils import validate_hf_hub_args +from ..utils.torch_utils import randn_tensor +from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class +from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .auto_pipeline import _get_model + +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +MODULAR_PIPELINE_MAPPING = { + "stable-diffusion-xl": "StableDiffusionXLModularPipeline", +} + + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + outputs: Dict[str, Any] = field(default_factory=dict) + + def add_input(self, key: str, value: Any): + self.inputs[key] = value + + def add_intermediate(self, key: str, value: Any): + self.intermediates[key] = value + + def add_output(self, key: str, value: Any): + self.outputs[key] = value + + def get_input(self, key: str, default: Any = None) -> Any: + return self.inputs.get(key, default) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_output(self, key: str, default: Any = None) -> Any: + return self.outputs.get(key, default) + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(\n dtype={v.dtype}, shape={v.shape}\n {v})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(\n dtype={v[0].dtype}, shape={v[0].shape}\n {v[0]}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" outputs={{\n{outputs}\n }}\n" + f")" + ) + + +class PipelineBlock: + optional_components = [] + required_components = [] + required_auxiliaries = [] + + @property + def inputs(self) -> Tuple[Tuple[str, Any], ...]: + # (input_name, default_value) + return () + + @property + def intermediates_inputs(self) -> List[str]: + return [] + + @property + def intermediates_outputs(self) -> List[str]: + return [] + + def __init__(self, **kwargs): + self.components: Dict[str, Any] = {} + self.auxiliaries: Dict[str, Any] = {} + self.configs: Dict[str, Any] = {} + + # Process kwargs + for key, value in kwargs.items(): + if key in self.required_components or key in self.optional_components: + self.components[key] = value + elif key in self.required_auxiliaries: + self.auxiliaries[key] = value + else: + self.configs[key] = value + + @classmethod + def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): + """ + Create a PipelineBlock instance from a diffusion pipeline object. + + Args: + pipe: A `[DiffusionPipeline]` object. + + Returns: + PipelineBlock: An instance initialized with the pipeline's components and configurations. + """ + # add components + expected_components = set(cls.required_components + cls.optional_components) + # - components that are passed in kwargs + components_to_add = { + component_name: kwargs.pop(component_name) + for component_name in expected_components + if component_name in kwargs + } + # - components that are in the pipeline + for component_name, component in pipe.components.items(): + if component_name in expected_components and component_name not in components_to_add: + components_to_add[component_name] = component + + # add auxiliaries + # - auxiliaries that are passed in kwargs + auxiliaries_to_add = {k: kwargs.pop(k) for k in cls.required_auxiliaries if k in kwargs} + # - auxiliaries that are in the pipeline + for aux_name in cls.required_auxiliaries: + if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add: + auxiliaries_to_add[aux_name] = getattr(pipe, aux_name) + block_kwargs = {**components_to_add, **auxiliaries_to_add} + + # add pipeline configs + init_params = inspect.signature(cls.__init__).parameters + # modules info are also registered in the config as tuples, e.g. {'tokenizer': ('transformers', 'CLIPTokenizer')} + # we need to exclude them for block_kwargs otherwise it will override the actual module + expected_configs = { + k + for k in pipe.config.keys() + if k in init_params and k not in expected_components and k not in cls.required_auxiliaries + } + + for config_name in expected_configs: + if config_name not in block_kwargs: + if config_name in kwargs: + # - configs that are passed in kwargs + block_kwargs[config_name] = kwargs.pop(config_name) + else: + # - configs that are in the pipeline + block_kwargs[config_name] = pipe.config[config_name] + + # Add any remaining relevant pipeline attributes + for attr_name in dir(pipe): + if attr_name not in block_kwargs and attr_name in init_params: + block_kwargs[attr_name] = getattr(pipe, attr_name) + + return cls(**block_kwargs) + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + components = ", ".join(f"{k}={type(v).__name__}" for k, v in self.components.items()) + auxiliaries = ", ".join(f"{k}={type(v).__name__}" for k, v in self.auxiliaries.items()) + configs = ", ".join(f"{k}={v}" for k, v in self.configs.items()) + inputs = ", ".join(f"{name}={default}" for name, default in self.inputs) + intermediates_inputs = ", ".join(self.intermediates_inputs) + intermediates_outputs = ", ".join(self.intermediates_outputs) + + return ( + f"{class_name}(\n" + f" components: {components}\n" + f" auxiliaries: {auxiliaries}\n" + f" configs: {configs}\n" + f" inputs: {inputs}\n" + f" intermediates_inputs: {intermediates_inputs}\n" + f" intermediates_outputs: {intermediates_outputs}\n" + f")" + ) + + +class ModularPipelineBuilder(ConfigMixin): + """ + Base class for all Modular pipelines. + + """ + config_name = "model_index.json" + model_cpu_offload_seq = None + hf_device_map = None + _exclude_from_cpu_offload = [] + default_pipeline_blocks = [] + + def __init__(self): + super().__init__() + self.register_to_config() + self.pipeline_blocks = [] + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.register_modules + def register_modules(self, **kwargs): + for name, module in kwargs.items(): + # retrieve library + if module is None or isinstance(module, (tuple, list)) and module[0] is None: + register_dict = {name: (None, None)} + else: + library, class_name = _fetch_class_library_tuple(module) + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + @property + def components(self) -> Dict[str, Any]: + r""" + The `self.components` property returns all modules needed to initialize the pipeline, as defined by the + pipeline blocks. + + Returns (`dict`): + A dictionary containing all the components defined in the pipeline blocks. + """ + + expected_components = set() + for block in self.pipeline_blocks: + expected_components.update(block.components.keys()) + + components = {} + for name in expected_components: + if hasattr(self, name): + components[name] = getattr(self, name) + + return components + + @property + def auxiliaries(self) -> Dict[str, Any]: + r""" + The `self.auxiliaries` property returns all auxiliaries needed to initialize the pipeline, as defined by the + pipeline blocks. + + Returns (`dict`): + A dictionary containing all the auxiliaries defined in the pipeline blocks. + """ + # First collect all expected auxiliary names from blocks + expected_auxiliaries = set() + for block in self.pipeline_blocks: + expected_auxiliaries.update(block.auxiliaries.keys()) + + # Then fetch the actual auxiliaries from the pipeline + auxiliaries = {} + for name in expected_auxiliaries: + if hasattr(self, name): + auxiliaries[name] = getattr(self, name) + + return auxiliaries + + @property + def configs(self) -> Dict[str, Any]: + r""" + The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the + pipeline blocks. + + Returns (`dict`): + A dictionary containing all the configs defined in the pipeline blocks. + """ + # First collect all expected config names from blocks + expected_configs = set() + for block in self.pipeline_blocks: + expected_configs.update(block.configs.keys()) + + # Then fetch the actual configs from the pipeline's config + configs = {} + for name in expected_configs: + if name in self.config: + configs[name] = self.config[name] + + return configs + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + def __call__(self, *args, **kwargs): + raise NotImplementedError("__call__ is not implemented for ModularPipelineBuilder") + + def remove_blocks(self, indices: Union[int, List[int]]): + """ + Remove one or more blocks from the pipeline by their indices and clean up associated components, + configs, and auxiliaries that are no longer needed by remaining blocks. + + Args: + indices (Union[int, List[int]]): The index or list of indices of blocks to remove + """ + # Convert single index to list + indices = [indices] if isinstance(indices, int) else indices + + # Validate indices + for idx in indices: + if not 0 <= idx < len(self.pipeline_blocks): + raise ValueError(f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}") + + # Sort indices in descending order to avoid shifting issues when removing + indices = sorted(indices, reverse=True) + + # Store blocks to be removed + blocks_to_remove = [self.pipeline_blocks[idx] for idx in indices] + + # Remove blocks from pipeline + for idx in indices: + self.pipeline_blocks.pop(idx) + + + # Consolidate items to remove from all blocks + components_to_remove = {k: v for block in blocks_to_remove for k, v in block.components.items()} + auxiliaries_to_remove = {k: v for block in blocks_to_remove for k, v in block.auxiliaries.items()} + configs_to_remove = {k: v for block in blocks_to_remove for k, v in block.configs.items()} + + # The properties will now reflect only the remaining blocks + remaining_components = self.components + remaining_auxiliaries = self.auxiliaries + remaining_configs = self.configs + + # Clean up all items that are no longer needed + for component_name in components_to_remove: + if component_name not in remaining_components: + if component_name in self.config: + del self.config[component_name] + if hasattr(self, component_name): + delattr(self, component_name) + + for auxiliary_name in auxiliaries_to_remove: + if auxiliary_name not in remaining_auxiliaries: + if hasattr(self, auxiliary_name): + delattr(self, auxiliary_name) + + for config_name in configs_to_remove: + if config_name not in remaining_configs: + if config_name in self.config: + del self.config[config_name] + + def add_blocks(self, pipeline_blocks, at: int = -1): + """Add blocks to the pipeline. + + Args: + pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances. + at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end). + """ + # Convert single block to list for uniform processing + if not isinstance(pipeline_blocks, (list, tuple)): + pipeline_blocks = [pipeline_blocks] + + # Validate insert_at index + if at != -1 and not 0 <= at <= len(self.pipeline_blocks): + raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks)}") + + # Consolidate all items from blocks + components_to_add = {} + configs_to_add = {} + auxiliaries_to_add = {} + + # Add blocks in order + for i, block in enumerate(pipeline_blocks): + # Add block to pipeline at specified position + if at == -1: + self.pipeline_blocks.append(block) + else: + self.pipeline_blocks.insert(at + i, block) + + # Collect components that don't already exist + for k, v in block.components.items(): + if not hasattr(self, k) or (getattr(self, k, None) is None and v is not None): + components_to_add[k] = v + + # Collect configs and auxiliaries + configs_to_add.update(block.configs) + auxiliaries_to_add.update(block.auxiliaries) + + # Validate all required components and auxiliaries after consolidation + for block in pipeline_blocks: + for required_component in block.required_components: + if ( + not hasattr(self, required_component) + and required_component not in components_to_add + or getattr(self, required_component, None) is None + and components_to_add.get(required_component) is None + ): + raise ValueError( + f"Cannot add block {block.__class__.__name__}: Required component {required_component} not found in pipeline" + ) + + for required_auxiliary in block.required_auxiliaries: + if ( + not hasattr(self, required_auxiliary) + and required_auxiliary not in auxiliaries_to_add + or getattr(self, required_auxiliary, None) is None + and auxiliaries_to_add.get(required_auxiliary) is None + ): + raise ValueError( + f"Cannot add block {block.__class__.__name__}: Required auxiliary {required_auxiliary} not found in pipeline" + ) + + # Process all items in batches + if components_to_add: + self.register_modules(**components_to_add) + if configs_to_add: + self.register_to_config(**configs_to_add) + for key, value in auxiliaries_to_add.items(): + + setattr(self, key, value) + + def replace_blocks(self, pipeline_blocks, at: int): + """Replace one or more blocks in the pipeline at the specified index. + + Args: + pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances + that will replace existing blocks. + at (int): Index at which to replace the blocks. + """ + # Convert single block to list for uniform processing + if not isinstance(pipeline_blocks, (list, tuple)): + pipeline_blocks = [pipeline_blocks] + + # Validate replace_at index + if not 0 <= at < len(self.pipeline_blocks): + raise ValueError( + f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}" + ) + + # Add new blocks first + self.add_blocks(pipeline_blocks, at=at) + + # Calculate indices to remove + # We need to remove the original blocks that are now shifted by the length of pipeline_blocks + indices_to_remove = list(range( + at + len(pipeline_blocks), + at + len(pipeline_blocks) * 2 + )) + + # Remove the old blocks + self.remove_blocks(indices_to_remove) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + + # (1) create the base pipeline + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + base_pipeline_class_name = config["_class_name"] + base_pipeline_class = _get_pipeline_class(cls, config) + + kwargs = {**load_config_kwargs, **kwargs} + base_pipeline = base_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) + + # (2) map the base pipeline to pipeline blocks + modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] + modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) + + + # (3) create the pipeline blocks + pipeline_blocks = [ + block_class.from_pipe(base_pipeline) + for block_class in modular_pipeline_class.default_pipeline_blocks + ] + + # (4) create the builder + builder = modular_pipeline_class() + builder.add_blocks(pipeline_blocks) + + return builder + + @classmethod + def from_pipe(cls, pipeline, **kwargs): + base_pipeline_class_name = pipeline.__class__.__name__ + modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] + modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) + + pipeline_blocks = [] + # Create each block, passing only unused items that the block expects + for block_class in modular_pipeline_class.default_pipeline_blocks: + expected_components = set(block_class.required_components + block_class.optional_components) + expected_auxiliaries = set(block_class.required_auxiliaries) + + # Get init parameters to check for expected configs + init_params = inspect.signature(block_class.__init__).parameters + expected_configs = { + k for k in init_params + if k not in expected_components + and k not in expected_auxiliaries + } + + block_kwargs = {} + + for key, value in kwargs.items(): + if (key in expected_components or + key in expected_auxiliaries or + key in expected_configs): + block_kwargs[key] = value + + # Create the block with filtered kwargs + block = block_class.from_pipe(pipeline, **block_kwargs) + pipeline_blocks.append(block) + + # Create and setup the builder + builder = modular_pipeline_class() + builder.add_blocks(pipeline_blocks) + + # Warn about unused kwargs + unused_kwargs = { + k: v for k, v in kwargs.items() + if not any( + k in block.components or k in block.auxiliaries or k in block.configs + for block in pipeline_blocks + ) + } + if unused_kwargs: + logger.warning( + f"The following items were passed but not used by any pipeline block: {list(unused_kwargs.keys())}" + ) + + return builder + + def run_blocks(self, state: PipelineState = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # user can pass the intermediate of the first block + for name in self.pipeline_blocks[0].intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + for name, default in default_params.items(): + if name in input_params: + state.add_input(name, input_params.pop(name)) + elif name not in state.inputs: + state.add_input(name, default) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + for block in self.pipeline_blocks: + try: + pipeline, state = block(self, state) + except Exception: + error_msg = f"Error in block: ({block.__class__.__name__}):\n" + logger.error(error_msg) + raise + + return state + + def run_pipeline(self, **kwargs): + state = PipelineState() + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # Add inputs to state, using defaults if not provided + for name, default in default_params.items(): + if name in input_params: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, default) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + + # Run the pipeline + with torch.no_grad(): + for block in self.pipeline_blocks: + try: + pipeline, state = block(self, state) + except Exception: + error_msg = f"Error in block: ({block.__class__.__name__}):\n" + logger.error(error_msg) + raise + + return state.get_output("images") + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for block in self.pipeline_blocks: + for name, default in block.inputs: + if name not in params: + params[name] = default + return params + + def __repr__(self): + output = "CustomPipeline Configuration:\n" + output += "==============================\n\n" + + # List the blocks used to build the pipeline + output += "Pipeline Blocks:\n" + output += "----------------\n" + for i, block in enumerate(self.pipeline_blocks): + output += f"{i}. {block.__class__.__name__}\n" + + intermediates_str = "" + if hasattr(block, "intermediates_inputs"): + intermediates_str += f"{', '.join(block.intermediates_inputs)}" + + if hasattr(block, "intermediates_outputs"): + if intermediates_str: + intermediates_str += " -> " + else: + intermediates_str += "-> " + intermediates_str += f"{', '.join(block.intermediates_outputs)}" + + if intermediates_str: + output += f" {intermediates_str}\n" + + output += "\n" + output += "\n" + + # List the components registered in the pipeline + output += "Registered Components:\n" + output += "----------------------\n" + for name, component in self.components.items(): + output += f"{name}: {type(component).__name__}" + if hasattr(component, "dtype") and hasattr(component, "device"): + output += f" (dtype={component.dtype}, device={component.device})" + output += "\n" + output += "\n" + + # List the auxiliaries registered in the pipeline + output += "Registered Auxiliaries:\n" + output += "----------------------\n" + for name, auxiliary in self.auxiliaries.items(): + output += f"{name}: {type(auxiliary).__name__}\n" + output += "\n" + + # List the configs registered in the pipeline + output += "Registered Configs:\n" + output += "------------------\n" + for name, config in self.configs.items(): + output += f"{name}: {config!r}\n" + output += "\n" + + + # List the default call parameters + output += "Default Call Parameters:\n" + output += "------------------------\n" + params = self.default_call_parameters + for name, default in params.items(): + output += f"{name}: {default!r}\n" + + # Add a section for required call parameters: + # intermediate inputs for the first block + output += "\nRequired Call Parameters:\n" + output += "--------------------------\n" + for name in self.pipeline_blocks[0].intermediates_inputs: + output += f"{name}: \n" + params[name] = "" + + output += "\nNote: These are the default values. Actual values may be different when running the pipeline." + return output + + # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to + def to(self, *args, **kwargs): + r""" + Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the + arguments of `self.to(*args, **kwargs).` + + + + If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, + the returned pipeline is a copy of self with the desired torch.dtype and torch.device. + + + + + Here are the ways to call `to`: + + - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the + specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + Arguments: + dtype (`torch.dtype`, *optional*): + Returns a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + device (`torch.Device`, *optional*): + Returns a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + silence_dtype_warnings (`str`, *optional*, defaults to `False`): + Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + + Returns: + [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + """ + dtype = kwargs.pop("dtype", None) + device = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + return hasattr(module, "_hf_hook") and ( + isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + or hasattr(module._hf_hook, "hooks") + and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) + ) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and device and torch.device(device).type == "cuda": + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." + ) + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + + if is_loaded_in_8bit and dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." + ) + + if is_loaded_in_8bit and device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." + ) + else: + module.to(device, dtype) + + if ( + module.dtype == torch.float16 + and str(device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 0a744264b7a6..0fd640caeefb 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -351,7 +351,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline": + if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipelineBuilder": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 8088fbcfceba..ab5c6bde7d54 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,6 +29,17 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] + _import_structure["pipeline_stable_diffusion_xl_modular"] = [ + "StableDiffusionXLDecodeLatentsStep", + "StableDiffusionXLDenoiseStep", + "StableDiffusionXLInputStep", + "StableDiffusionXLModularPipeline", + "StableDiffusionXLPrepareAdditionalConditioningStep", + "StableDiffusionXLPrepareLatentsStep", + "StableDiffusionXLSetTimestepsStep", + "StableDiffusionXLTextEncoderStep", + "StableDiffusionXLControlNetDenoiseStep", + ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -48,6 +59,17 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline + from .pipeline_stable_diffusion_xl_modular import ( + StableDiffusionXLDecodeLatentsStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLModularPipeline, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLControlNetDenoiseStep, + ) try: if not (is_transformers_available() and is_flax_available()): diff --git a/src/diffusers/pipelines/custom_pipeline_builder.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py similarity index 78% rename from src/diffusers/pipelines/custom_pipeline_builder.py rename to src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index d97cc32a4643..061b6a8dd0f5 100644 --- a/src/diffusers/pipelines/custom_pipeline_builder.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -13,29 +13,29 @@ # limitations under the License. import inspect -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import PIL import torch -from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from ..configuration_utils import ConfigMixin -from ..guider import CFGGuider -from ..image_processor import VaeImageProcessor -from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ..models import ControlNetModel, ImageProjection -from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ..models.lora import adjust_lora_scale_text_encoder -from ..utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..utils.torch_utils import is_compiled_module, randn_tensor -from .controlnet.multicontrolnet import MultiControlNetModel -from .pipeline_loading_utils import _fetch_class_library_tuple -from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .stable_diffusion_xl import ( - StableDiffusionXLPipeline, +from ...guider import CFGGuider +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import ControlNetModel, ImageProjection +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import ( StableDiffusionXLPipelineOutput, ) @@ -43,151 +43,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class CustomPipeline(ConfigMixin): - """ - Base class for all custom pipelines built with CustomPipelineBuilder. - - [`CustomPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines. Unlike - [`DiffusionPipeline`], it's designed to be used exclusively with [`CustomPipelineBuilder`] and does not have a - `__call__` method. It cannot be called directly and must be run via the builder's run_pipeline method. - Additionally, it does not include methods for loading, downloading, or saving models, focusing only on - inference-related tasks, such as: - - - move all PyTorch modules to the device of your choice - - enable/disable the progress bar for the denoising iteration - - Usage: This class should not be instantiated directly. Instead, use CustomPipelineBuilder to create and configure a - CustomPipeline instance. - - Example: - builder = CustomPipelineBuilder("SDXL") builder.add_blocks([InputStep(), TextEncoderStep(), ...]) result = - builder.run_pipeline(prompt="A beautiful sunset") - - Class Attributes: - config_name (str): Filename for the configuration storing component class and module names. - - Note: This class is part of a modular pipeline system and is intended to be used in conjunction with - CustomPipelineBuilder for maximum flexibility and customization in diffusion pipelines. - """ - - config_name = "model_index.json" - model_cpu_offload_seq = None - hf_device_map = None - _exclude_from_cpu_offload = [] - - def __init__(self): - super().__init__() - self.register_to_config() - self.builder = None - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.register_modules - def register_modules(self, **kwargs): - for name, module in kwargs.items(): - # retrieve library - if module is None or isinstance(module, (tuple, list)) and module[0] is None: - register_dict = {name: (None, None)} - else: - library, class_name = _fetch_class_library_tuple(module) - register_dict = {name: (library, class_name)} - - # save model index config - self.register_to_config(**register_dict) - - # set models - setattr(self, name, module) - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - @property - def components(self) -> Dict[str, Any]: - r""" - The `self.components` property returns all modules needed to initialize the pipeline, as defined by the - pipeline blocks. - - Returns (`dict`): - A dictionary containing all the components defined in the pipeline blocks. - """ - if not hasattr(self, "builder") or self.builder is None: - raise ValueError("Pipeline builder is not set. Cannot retrieve components.") - - components = {} - for block in self.builder.pipeline_blocks: - components.update(block.components) - - return components - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, *args, **kwargs): - raise NotImplementedError("__call__ is not implemented for CustomPipeline") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( @@ -263,2062 +118,1703 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SDXLCustomPipeline( - CustomPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, -): - def __init__(self): - super().__init__() - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size +class StableDiffusionXLInputStep(PipelineBlock): @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("prompt", None), + ("prompt_embeds", None), + ] @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_latents = self.unet.config.in_channels - return num_channels_latents + def intermediates_outputs(self) -> List[str]: + return ["batch_size"] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + prompt = state.get_input("prompt") + prompt_embeds = state.get_input("prompt_embeds") - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) + state.add_intermediate("batch_size", batch_size) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids + return pipeline, state - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids - def _get_add_time_ids_img2img( + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("prompt", None), + ("prompt_2", None), + ("negative_prompt", None), + ("negative_prompt_2", None), + ("cross_attention_kwargs", None), + ("prompt_embeds", None), + ("negative_prompt_embeds", None), + ("pooled_prompt_embeds", None), + ("negative_pooled_prompt_embeds", None), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("clip_skip", None), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + "prompt_embeds", + "negative_prompt_embeds", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + ] + + def __init__( self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, + text_encoder: Optional[CLIPTextModel] = None, + text_encoder_2: Optional[CLIPTextModelWithProjection] = None, + tokenizer: Optional[CLIPTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + force_zeros_for_empty_prompt: bool = True, ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + super().__init__( + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): + @staticmethod + def check_inputs( + pipeline, + prompt, + prompt_2, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): + elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) - elif expected_add_embed_dim != passed_add_embed_dim: + elif prompt is None and prompt_embeds is None: raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # return image without apply any guidance - def prepare_control_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Get inputs + prompt = state.get_input("prompt") + prompt_2 = state.get_input("prompt_2") + negative_prompt = state.get_input("negative_prompt") + negative_prompt_2 = state.get_input("negative_prompt_2") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + prompt_embeds = state.get_input("prompt_embeds") + negative_prompt_embeds = state.get_input("negative_prompt_embeds") + pooled_prompt_embeds = state.get_input("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_input("negative_pooled_prompt_embeds") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + clip_skip = state.get_input("clip_skip") - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + device = pipeline._execution_device - image = image.repeat_interleave(repeat_by, dim=0) + self.check_inputs( + pipeline, + prompt, + prompt_2, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) - image = image.to(device=device, dtype=dtype) + # Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # Add outputs + state.add_intermediate("prompt_embeds", prompt_embeds) + state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds) + state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds) + state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds) + return pipeline, state - return image - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + required_components = ["scheduler"] - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ("image", None), + ("strength", 0.3), + ("denoising_start", None), + ("num_images_per_prompt", 1), + ("device", None), + ] - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size"] - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) + @property + def intermediates_outputs(self) -> List[str]: + return ["timesteps", "num_inference_steps", "latent_timestep"] - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) + def __init__(self, scheduler=None): + super().__init__(scheduler=scheduler) - prompt = [prompt] if isinstance(prompt, str) else prompt + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + num_inference_steps = state.get_input("num_inference_steps") + timesteps = state.get_input("timesteps") + sigmas = state.get_input("sigmas") + denoising_end = state.get_input("denoising_end") + device = state.get_input("device") - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + # image to image only + image = state.get_input("image") # just to check if it is an image to image workflow + strength = state.get_input("strength") + denoising_start = state.get_input("denoising_start") + num_images_per_prompt = state.get_input("num_images_per_prompt") - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + # image to image only + batch_size = state.get_intermediate("batch_size") + + if device is None: + device = pipeline._execution_device + + timesteps, num_inference_steps = retrieve_timesteps( + pipeline.scheduler, num_inference_steps, device, timesteps, sigmas ) - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + if image is not None: - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) + timesteps, num_inference_steps = pipeline.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + else: + latent_timestep = None - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + pipeline.scheduler.config.num_train_timesteps + - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + state.add_intermediate("timesteps", timesteps) + state.add_intermediate("num_inference_steps", num_inference_steps) + state.add_intermediate("latent_timestep", latent_timestep) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + return pipeline, state - prompt_embeds_list.append(prompt_embeds) - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + optional_components = ["vae", "scheduler"] + required_auxiliaries = ["image_processor"] - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("height", None), + ("width", None), + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ("device", None), + ("dtype", None), + ("image", None), + ("denoising_start", None), + ] - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "latent_timestep"] - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + def __init__(self, vae=None, image_processor=None, vae_scale_factor=8, scheduler=None): + if image_processor is None: + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__( + vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor, scheduler=scheduler + ) - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) + @staticmethod + def check_inputs(pipeline, height, width, image): + if image is not None and (height is not None or width is not None): + raise ValueError("Cannot specify both `image` and `height` or `width`") - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + if ( + height is not None + and height % pipeline.vae_scale_factor != 0 + or width is not None + and width % pipeline.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {height} and {width}." + ) - negative_prompt_embeds_list.append(negative_prompt_embeds) + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + latents = state.get_input("latents") + num_images_per_prompt = state.get_input("num_images_per_prompt") + generator = state.get_input("generator") + device = state.get_input("device") + dtype = state.get_input("dtype") - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + # text to image only + height = state.get_input("height") + width = state.get_input("width") - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + # image to image only + image = state.get_input("image") + denoising_start = state.get_input("denoising_start") - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + batch_size = state.get_intermediate("batch_size") + prompt_embeds = state.get_intermediate("prompt_embeds", None) + # image to image only + latent_timestep = state.get_intermediate("latent_timestep", None) - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] + if dtype is None and prompt_embeds is not None: + dtype = prompt_embeds.dtype + elif dtype is None: + dtype = pipeline.vae.dtype - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + if device is None: + device = pipeline._execution_device - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + self.check_inputs(pipeline, height, width, image) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 + if image is None: + height = height or pipeline.default_sample_size * pipeline.vae_scale_factor + width = width or pipeline.default_sample_size * pipeline.vae_scale_factor + num_channels_latents = pipeline.num_channels_latents + latents = pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents, ) + else: + image = pipeline.image_processor.preprocess(image) + add_noise = True if denoising_start is None else False + if latents is None: + latents = pipeline.prepare_latents_img2img( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + add_noise, + ) - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] + state.add_intermediate("latents", latents) - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) + return pipeline, state - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + required_components = ["unet"] - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("aesthetic_score", 6.0), + ("negative_aesthetic_score", 2.0), + ("device", None), + ("image", None), + ] - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) + @property + def intermediates_inputs(self) -> List[str]: + return ["latents", "batch_size", "pooled_prompt_embeds"] - return ip_adapter_image_embeds + @property + def intermediates_outputs(self) -> List[str]: + return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) + def __init__(self, unet=None, requires_aesthetics_score=False): + super().__init__(unet=unet, requires_aesthetics_score=requires_aesthetics_score) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + original_size = state.get_input("original_size") + target_size = state.get_input("target_size") + negative_original_size = state.get_input("negative_original_size") + negative_target_size = state.get_input("negative_target_size") + crops_coords_top_left = state.get_input("crops_coords_top_left") + negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + device = state.get_input("device") - return timesteps, num_inference_steps - t_start + # image to image only + image = state.get_input("image") + aesthetic_score = state.get_input("aesthetic_score") + negative_aesthetic_score = state.get_input("negative_aesthetic_score") - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) - ) - ) + latents = state.get_intermediate("latents") + batch_size = state.get_intermediate("batch_size") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 + if device is None: + device = pipeline._execution_device - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + original_size = original_size or (height, width) + target_size = target_size or (height, width) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: + text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents - def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + if image is None: + add_time_ids = pipeline._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = pipeline._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( + device=device + ) + else: + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( + device=device + ) - image = image.to(device=device, dtype=dtype) + # Optionally get Guidance Scale Embedding for LCM + timestep_cond = None + if ( + hasattr(pipeline, "unet") + and pipeline.unet is not None + and pipeline.unet.config.time_cond_proj_dim is not None + ): + guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = pipeline.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) - batch_size = batch_size * num_images_per_prompt + state.add_intermediate("add_time_ids", add_time_ids) + state.add_intermediate("negative_add_time_ids", negative_add_time_ids) + state.add_intermediate("timestep_cond", timestep_cond) + return pipeline, state - if image.shape[1] == 4: - init_latents = image - else: - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) +class StableDiffusionXLDenoiseStep(PipelineBlock): + required_components = ["unet", "scheduler"] + required_auxiliaries = ["guider"] - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("guidance_scale", 5.0), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ("guider_kwargs", None), + ] - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) + @property + def intermediates_inputs(self) -> List[str]: + return [ + "latents", + "timesteps", + "num_inference_steps", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + "add_time_ids", + "negative_add_time_ids", + "timestep_cond", + "prompt_embeds", + "negative_prompt_embeds", + ] - init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] - if self.vae.config.force_upcast: - self.vae.to(dtype) + def __init__(self, unet=None, scheduler=None, guider=None): + if guider is None: + guider = CFGGuider() + super().__init__(unet=unet, scheduler=scheduler, guider=guider) - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - init_latents = self.vae.config.scaling_factor * init_latents + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + guidance_scale = state.get_input("guidance_scale") + guidance_rescale = state.get_input("guidance_rescale") - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + generator = state.get_input("generator") + eta = state.get_input("eta") + guider_kwargs = state.get_input("guider_kwargs") - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + batch_size = state.get_intermediate("batch_size") + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - latents = init_latents + timestep_cond = state.get_intermediate("timestep_cond") + latents = state.get_intermediate("latents") - return latents + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs + # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "disable_guidance": disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), + pipeline.guider.set_guider(pipeline, guider_kwargs) + # Prepare conditional inputs using the guider + prompt_embeds = pipeline.guider.prepare_input( + prompt_embeds, + negative_prompt_embeds, ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - outputs: Dict[str, Any] = field(default_factory=dict) - - def add_input(self, key: str, value: Any): - self.inputs[key] = value - - def add_intermediate(self, key: str, value: Any): - self.intermediates[key] = value - - def add_output(self, key: str, value: Any): - self.outputs[key] = value - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_output(self, key: str, default: Any = None) -> Any: - return self.outputs.get(key, default) - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(\n dtype={v.dtype}, shape={v.shape}\n {v})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(\n dtype={v[0].dtype}, shape={v[0].shape}\n {v[0]}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" outputs={{\n{outputs}\n }}\n" - f")" + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, ) - - -class PipelineBlock: - optional_components = [] - required_components = [] - required_auxiliaries = [] - - @property - def inputs(self) -> Tuple[Tuple[str, Any], ...]: - # (input_name, default_value) - return () - - @property - def intermediates_inputs(self) -> List[str]: - return [] - - @property - def intermediates_outputs(self) -> List[str]: - return [] - - def __init__(self, **kwargs): - self.components: Dict[str, Any] = {} - self.auxiliaries: Dict[str, Any] = {} - self.configs: Dict[str, Any] = {} - - # Process kwargs - for key, value in kwargs.items(): - if key in self.required_components or key in self.optional_components: - self.components[key] = value - elif key in self.required_auxiliaries: - self.auxiliaries[key] = value - else: - self.configs[key] = value - - @classmethod - def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): - """ - Create a PipelineBlock instance from a diffusion pipeline object. - - Args: - pipe: A `[DiffusionPipeline]` object. - - Returns: - PipelineBlock: An instance initialized with the pipeline's components and configurations. - """ - kwargs = kwargs.copy() - # add components - expected_components = set(cls.required_components + cls.optional_components) - # - components that are passed in kwargs - components_to_add = { - component_name: kwargs.pop(component_name) - for component_name in expected_components - if component_name in kwargs - } - # - components that are in the pipeline - for component_name, component in pipe.components.items(): - if component_name in expected_components and component_name not in components_to_add: - components_to_add[component_name] = component - - # add auxiliaries - # - auxiliaries that are passed in kwargs - auxiliaries_to_add = {k: kwargs.pop(k) for k in cls.required_auxiliaries if k in kwargs} - # - auxiliaries that are in the pipeline - for aux_name in cls.required_auxiliaries: - if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add: - auxiliaries_to_add[aux_name] = getattr(pipe, aux_name) - block_kwargs = {**components_to_add, **auxiliaries_to_add} - - # add pipeline configs - init_params = inspect.signature(cls.__init__).parameters - # modules info are also registered in the config as tuples, e.g. {'tokenizer': ('transformers', 'CLIPTokenizer')} - # we need to exclude them for block_kwargs otherwise it will override the actual module - expected_configs = { - k - for k in pipe.config.keys() - if k in init_params and k not in expected_components and k not in cls.required_auxiliaries - } - - for config_name in expected_configs: - if config_name not in block_kwargs: - if config_name in kwargs: - # - configs that are passed in kwargs - block_kwargs[config_name] = kwargs.pop(config_name) - else: - # - configs that are in the pipeline - block_kwargs[config_name] = pipe.config[config_name] - - # Add any remaining relevant pipeline attributes - for attr_name in dir(pipe): - if attr_name not in block_kwargs and attr_name in init_params: - block_kwargs[attr_name] = getattr(pipe, attr_name) - - return cls(**block_kwargs) - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - components = ", ".join(f"{k}={type(v).__name__}" for k, v in self.components.items()) - auxiliaries = ", ".join(f"{k}={type(v).__name__}" for k, v in self.auxiliaries.items()) - configs = ", ".join(f"{k}={v}" for k, v in self.configs.items()) - inputs = ", ".join(f"{name}={default}" for name, default in self.inputs) - intermediates_inputs = ", ".join(self.intermediates_inputs) - intermediates_outputs = ", ".join(self.intermediates_outputs) - - return ( - f"{class_name}(\n" - f" components: {components}\n" - f" auxiliaries: {auxiliaries}\n" - f" configs: {configs}\n" - f" inputs: {inputs}\n" - f" intermediates_inputs: {intermediates_inputs}\n" - f" intermediates_outputs: {intermediates_outputs}\n" - f")" + pooled_prompt_embeds = pipeline.guider.prepare_input( + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } -class InputStep(PipelineBlock): - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("prompt", None), - ("prompt_embeds", None), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return ["batch_size"] + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - prompt = state.get_input("prompt") - prompt_embeds = state.get_input("prompt_embeds") + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = pipeline.guider.prepare_input(latents, latents) + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = pipeline.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + # perform guidance + noise_pred = pipeline.guider.apply_guidance( + noise_pred, + timestep=t, + ) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() - state.add_intermediate("batch_size", batch_size) + pipeline.guider.reset_guider(pipeline) + state.add_intermediate("latents", latents) return pipeline, state -class TextEncoderStep(PipelineBlock): - optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] +class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + required_components = ["unet", "controlnet", "scheduler"] + required_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("prompt", None), - ("prompt_2", None), - ("negative_prompt", None), - ("negative_prompt_2", None), - ("cross_attention_kwargs", None), - ("prompt_embeds", None), - ("negative_prompt_embeds", None), - ("pooled_prompt_embeds", None), - ("negative_pooled_prompt_embeds", None), + ("control_image", None), + ("control_guidance_start", 0.0), + ("control_guidance_end", 1.0), + ("controlnet_conditioning_scale", 1.0), + ("guess_mode", False), ("num_images_per_prompt", 1), ("guidance_scale", 5.0), - ("clip_skip", None), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ("guider_kwargs", None), ] @property - def intermediates_outputs(self) -> List[str]: + def intermediates_inputs(self) -> List[str]: return [ + "latents", + "batch_size", + "timesteps", + "num_inference_steps", "prompt_embeds", "negative_prompt_embeds", + "add_time_ids", + "negative_add_time_ids", "pooled_prompt_embeds", "negative_pooled_prompt_embeds", + "timestep_cond", ] + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + def __init__( self, - text_encoder: Optional[CLIPTextModel] = None, - text_encoder_2: Optional[CLIPTextModelWithProjection] = None, - tokenizer: Optional[CLIPTokenizer] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - force_zeros_for_empty_prompt: bool = True, + unet=None, + controlnet=None, + scheduler=None, + guider=None, + controlnet_guider=None, + control_image_processor=None, + vae_scale_factor=8.0, ): + if guider is None: + guider = CFGGuider() + if controlnet_guider is None: + controlnet_guider = CFGGuider() + if control_image_processor is None: + control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False) super().__init__( - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + guider=guider, + controlnet_guider=controlnet_guider, + control_image_processor=control_image_processor, + vae_scale_factor=vae_scale_factor, ) - @staticmethod - def check_inputs( - pipeline, - prompt, - prompt_2, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - ): - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs - prompt = state.get_input("prompt") - prompt_2 = state.get_input("prompt_2") - negative_prompt = state.get_input("negative_prompt") - negative_prompt_2 = state.get_input("negative_prompt_2") + guidance_scale = state.get_input("guidance_scale") + guidance_rescale = state.get_input("guidance_rescale") cross_attention_kwargs = state.get_input("cross_attention_kwargs") - prompt_embeds = state.get_input("prompt_embeds") - negative_prompt_embeds = state.get_input("negative_prompt_embeds") - pooled_prompt_embeds = state.get_input("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_input("negative_pooled_prompt_embeds") + guider_kwargs = state.get_input("guider_kwargs") + generator = state.get_input("generator") + eta = state.get_input("eta") num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - clip_skip = state.get_input("clip_skip") - - do_classifier_free_guidance = guidance_scale > 1.0 - device = pipeline._execution_device - - self.check_inputs( - pipeline, - prompt, - prompt_2, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - - # Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = pipeline.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - # Add outputs - state.add_intermediate("prompt_embeds", prompt_embeds) - state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds) - state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds) - state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds) - return pipeline, state - - -class SetTimestepsStep(PipelineBlock): - required_components = ["scheduler"] + # controlnet-specific inputs + control_image = state.get_input("control_image") + control_guidance_start = state.get_input("control_guidance_start") + control_guidance_end = state.get_input("control_guidance_end") + controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") + guess_mode = state.get_input("guess_mode") - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), - ("image", None), - ("strength", 0.3), - ("denoising_start", None), - ("num_images_per_prompt", 1), - ("device", None), - ] + batch_size = state.get_intermediate("batch_size") + latents = state.get_intermediate("latents") + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") - @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size"] + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - @property - def intermediates_outputs(self) -> List[str]: - return ["timesteps", "num_inference_steps", "latent_timestep"] + timestep_cond = state.get_intermediate("timestep_cond") - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) + device = pipeline._execution_device - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - num_inference_steps = state.get_input("num_inference_steps") - timesteps = state.get_input("timesteps") - sigmas = state.get_input("sigmas") - denoising_end = state.get_input("denoising_end") - device = state.get_input("device") + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor - # image to image only - image = state.get_input("image") # just to check if it is an image to image workflow - strength = state.get_input("strength") - denoising_start = state.get_input("denoising_start") - num_images_per_prompt = state.get_input("num_images_per_prompt") + # prepare controlnet inputs + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - # image to image only - batch_size = state.get_intermediate("batch_size") + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) - if device is None: - device = pipeline._execution_device + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - timesteps, num_inference_steps = retrieve_timesteps( - pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions ) + guess_mode = guess_mode or global_pool_conditions - if image is not None: - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - timesteps, num_inference_steps = pipeline.get_timesteps( - num_inference_steps, - strength, - device, - denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = pipeline.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - else: - latent_timestep = None + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + for control_image_ in control_image: + control_image = pipeline.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - state.add_intermediate("timesteps", timesteps) - state.add_intermediate("num_inference_steps", num_inference_steps) - state.add_intermediate("latent_timestep", latent_timestep) + control_images.append(control_image) - return pipeline, state + control_image = control_images + else: + assert False + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) -class PrepareLatentsStep(PipelineBlock): - optional_components = ["vae", "scheduler"] - required_auxiliaries = ["image_processor"] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("height", None), - ("width", None), - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), - ("device", None), - ("dtype", None), - ("image", None), - ("denoising_start", None), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "latent_timestep"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] - - def __init__(self, vae=None, image_processor=None, vae_scale_factor=8, scheduler=None): - if image_processor is None: - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__( - vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor, scheduler=scheduler + # Prepare conditional inputs for unet using the guider + # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "disable_guidance": disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.guider.set_guider(pipeline, guider_kwargs) + prompt_embeds = pipeline.guider.prepare_input( + prompt_embeds, + negative_prompt_embeds, + ) + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, + ) + pooled_prompt_embeds = pipeline.guider.prepare_input( + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) - @staticmethod - def check_inputs(pipeline, height, width, image): - if image is not None and (height is not None or width is not None): - raise ValueError("Cannot specify both `image` and `height` or `width`") - - if ( - height is not None - and height % pipeline.vae_scale_factor != 0 - or width is not None - and width % pipeline.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {height} and {width}." - ) + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - latents = state.get_input("latents") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - device = state.get_input("device") - dtype = state.get_input("dtype") + # Prepare conditional inputs for controlnet using the guider + controlnet_disable_guidance = True if disable_guidance or guess_mode else False + controlnet_guider_kwargs = guider_kwargs or {} + controlnet_guider_kwargs = { + **controlnet_guider_kwargs, + "disable_guidance": controlnet_disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) + controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) + controlnet_added_cond_kwargs = { + "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), + "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), + } + # controlnet-specific inputs: control_image + control_image = pipeline.controlnet_guider.prepare_input(control_image, control_image) - # text to image only - height = state.get_input("height") - width = state.get_input("width") + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) - # image to image only - image = state.get_input("image") - denoising_start = state.get_input("denoising_start") + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # prepare latents for unet using the guider + latent_model_input = pipeline.guider.prepare_input(latents, latents) - batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds", None) - # image to image only - latent_timestep = state.get_intermediate("latent_timestep", None) + # prepare latents for controlnet using the guider + control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) - if dtype is None and prompt_embeds is not None: - dtype = prompt_embeds.dtype - elif dtype is None: - dtype = pipeline.vae.dtype + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = pipeline.controlnet( + pipeline.scheduler.scale_model_input(control_model_input, t), + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) - if device is None: - device = pipeline._execution_device + # when we apply guidance for unet, but not for controlnet: + # add 0 to the unconditional batch + down_block_res_samples = pipeline.guider.prepare_input( + down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] + ) + mid_block_res_sample = pipeline.guider.prepare_input( + mid_block_res_sample, torch.zeros_like(mid_block_res_sample) + ) - self.check_inputs(pipeline, height, width, image) + noise_pred = pipeline.unet( + pipeline.scheduler.scale_model_input(latent_model_input, t), + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + # perform guidance + noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) - if image is None: - height = height or pipeline.default_sample_size * pipeline.vae_scale_factor - width = width or pipeline.default_sample_size * pipeline.vae_scale_factor - num_channels_latents = pipeline.num_channels_latents - latents = pipeline.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents, - ) - else: - image = pipeline.image_processor.preprocess(image) - add_noise = True if denoising_start is None else False - if latents is None: - latents = pipeline.prepare_latents_img2img( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator, - add_noise, - ) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + pipeline.guider.reset_guider(pipeline) + pipeline.controlnet_guider.reset_guider(pipeline) state.add_intermediate("latents", latents) return pipeline, state -class PrepareAdditionalConditioningStep(PipelineBlock): - required_components = ["unet"] +class StableDiffusionXLDecodeLatentsStep(PipelineBlock): + optional_components = ["vae"] + required_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("original_size", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("aesthetic_score", 6.0), - ("negative_aesthetic_score", 2.0), - ("device", None), - ("image", None), + ("output_type", "pil"), + ("return_dict", True), ] @property def intermediates_inputs(self) -> List[str]: - return ["latents", "batch_size", "pooled_prompt_embeds"] + return ["latents"] @property def intermediates_outputs(self) -> List[str]: - return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + return ["images"] - def __init__(self, unet=None, requires_aesthetics_score=False): - super().__init__(unet=unet, requires_aesthetics_score=requires_aesthetics_score) + def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): + if image_processor is None: + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - original_size = state.get_input("original_size") - target_size = state.get_input("target_size") - negative_original_size = state.get_input("negative_original_size") - negative_target_size = state.get_input("negative_target_size") - crops_coords_top_left = state.get_input("crops_coords_top_left") - negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - device = state.get_input("device") - - # image to image only - image = state.get_input("image") - aesthetic_score = state.get_input("aesthetic_score") - negative_aesthetic_score = state.get_input("negative_aesthetic_score") + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + output_type = state.get_input("output_type") + return_dict = state.get_input("return_dict") latents = state.get_intermediate("latents") - batch_size = state.get_intermediate("batch_size") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - - if device is None: - device = pipeline._execution_device - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: - text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim - else: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + if needs_upcasting: + pipeline.upcast_vae() + latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != pipeline.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + pipeline.vae = pipeline.vae.to(latents.dtype) - if image is None: - add_time_ids = pipeline._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = pipeline._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + has_latents_std = ( + hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None + ) + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) + latents = latents * latents_std / pipeline.vae.config.scaling_factor + latents_mean else: - negative_add_time_ids = add_time_ids - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( - device=device - ) + latents = latents / pipeline.vae.config.scaling_factor + + image = pipeline.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + pipeline.vae.to(dtype=torch.float16) else: - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size + image = latents - add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( - device=device - ) + # apply watermark if available + if hasattr(pipeline, "watermark") and pipeline.watermark is not None: + image = pipeline.watermark.apply_watermark(image) - # Optionally get Guidance Scale Embedding for LCM - timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = pipeline.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) + image = pipeline.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + output = (image,) + else: + output = StableDiffusionXLPipelineOutput(images=image) + + state.add_intermediate("images", image) + state.add_output("images", output) - state.add_intermediate("add_time_ids", add_time_ids) - state.add_intermediate("negative_add_time_ids", negative_add_time_ids) - state.add_intermediate("timestep_cond", timestep_cond) return pipeline, state -class DenoiseStep(PipelineBlock): - required_components = ["unet", "scheduler"] - required_auxiliaries = ["guider"] + +class StableDiffusionXLModularPipeline( + ModularPipelineBuilder, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, +): + + default_pipeline_blocks = [ + StableDiffusionXLInputStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLDecodeLatentsStep + ] + + def __init__(self): + super().__init__() @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ("guider_kwargs", None), - ] + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size @property - def intermediates_inputs(self) -> List[str]: - return [ - "latents", - "timesteps", - "num_inference_steps", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "add_time_ids", - "negative_add_time_ids", - "timestep_cond", - "prompt_embeds", - "negative_prompt_embeds", - ] + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_latents = self.unet.config.in_channels + return num_channels_latents - def __init__(self, unet=None, scheduler=None, guider=None): - if guider is None: - guider = CFGGuider() - super().__init__(unet=unet, scheduler=scheduler, guider=guider) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - guidance_rescale = state.get_input("guidance_rescale") + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids_img2img( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # return image without apply any guidance + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - generator = state.get_input("generator") - eta = state.get_input("eta") - guider_kwargs = state.get_input("guider_kwargs") + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale - batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) - timestep_cond = state.get_intermediate("timestep_cond") - latents = state.get_intermediate("latents") + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") - disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + prompt = [prompt] if isinstance(prompt, str) else prompt - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - guider_kwargs = guider_kwargs or {} - guider_kwargs = { - **guider_kwargs, - "disable_guidance": disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, - } + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - pipeline.guider.set_guider(pipeline, guider_kwargs) - # Prepare conditional inputs using the guider - prompt_embeds = pipeline.guider.prepare_input( - prompt_embeds, - negative_prompt_embeds, - ) - add_time_ids = pipeline.guider.prepare_input( - add_time_ids, - negative_add_time_ids, - ) - pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) - added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds, - "time_ids": add_time_ids, - } + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) - with pipeline.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_input(latents, latents) - latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = pipeline.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - noise_pred = pipeline.guider.apply_guidance( - noise_pred, - timestep=t, + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - pipeline.guider.reset_guider(pipeline) - state.add_intermediate("latents", latents) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) - return pipeline, state + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] -class ControlNetDenoiseStep(PipelineBlock): - required_components = ["unet", "controlnet", "scheduler"] - required_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] + prompt_embeds_list.append(prompt_embeds) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("control_image", None), - ("control_guidance_start", 0.0), - ("control_guidance_end", 1.0), - ("controlnet_conditioning_scale", 1.0), - ("guess_mode", False), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ("guider_kwargs", None), - ] + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - @property - def intermediates_inputs(self) -> List[str]: - return [ - "latents", - "batch_size", - "timesteps", - "num_inference_steps", - "prompt_embeds", - "negative_prompt_embeds", - "add_time_ids", - "negative_add_time_ids", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "timestep_cond", - ] + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt - @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) - def __init__( - self, - unet=None, - controlnet=None, - scheduler=None, - guider=None, - controlnet_guider=None, - control_image_processor=None, - vae_scale_factor=8.0, - ): - if guider is None: - guider = CFGGuider() - if controlnet_guider is None: - controlnet_guider = CFGGuider() - if control_image_processor is None: - control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__( - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - guider=guider, - controlnet_guider=controlnet_guider, - control_image_processor=control_image_processor, - vae_scale_factor=vae_scale_factor, - ) + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - guidance_rescale = state.get_input("guidance_rescale") - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - guider_kwargs = state.get_input("guider_kwargs") - generator = state.get_input("generator") - eta = state.get_input("eta") - num_images_per_prompt = state.get_input("num_images_per_prompt") - # controlnet-specific inputs - control_image = state.get_input("control_image") - control_guidance_start = state.get_input("control_guidance_start") - control_guidance_end = state.get_input("control_guidance_end") - controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") - guess_mode = state.get_input("guess_mode") + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - batch_size = state.get_intermediate("batch_size") - latents = state.get_intermediate("latents") - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - timestep_cond = state.get_intermediate("timestep_cond") + negative_prompt_embeds_list.append(negative_prompt_embeds) - device = pipeline._execution_device + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - # prepare controlnet inputs - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # 4. Prepare image - if isinstance(controlnet, ControlNetModel): - control_image = pipeline.prepare_control_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - for control_image_ in control_image: - control_image = pipeline.prepare_control_image( - image=control_image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - ) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) - control_images.append(control_image) + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) - control_image = control_images - else: - assert False + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] - # Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - guider_kwargs = guider_kwargs or {} - guider_kwargs = { - **guider_kwargs, - "disable_guidance": disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, - } - pipeline.guider.set_guider(pipeline, guider_kwargs) - prompt_embeds = pipeline.guider.prepare_input( - prompt_embeds, - negative_prompt_embeds, - ) - add_time_ids = pipeline.guider.prepare_input( - add_time_ids, - negative_add_time_ids, - ) - pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) - added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds, - "time_ids": add_time_ids, - } + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) - # Prepare conditional inputs for controlnet using the guider - controlnet_disable_guidance = True if disable_guidance or guess_mode else False - controlnet_guider_kwargs = guider_kwargs or {} - controlnet_guider_kwargs = { - **controlnet_guider_kwargs, - "disable_guidance": controlnet_disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, - } - pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) - controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) - controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), - } - # controlnet-specific inputs: control_image - control_image = pipeline.controlnet_guider.prepare_input(control_image, control_image) + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - with pipeline.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # prepare latents for unet using the guider - latent_model_input = pipeline.guider.prepare_input(latents, latents) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) - # prepare latents for controlnet using the guider - control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) + return ip_adapter_image_embeds - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(control_model_input, t), - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, - return_dict=False, - ) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - down_block_res_samples = pipeline.guider.prepare_input( - down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] - ) - mid_block_res_sample = pipeline.guider.prepare_input( - mid_block_res_sample, torch.zeros_like(mid_block_res_sample) - ) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) - noise_pred = pipeline.unet( - pipeline.scheduler.scale_model_input(latent_model_input, t), - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t) - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + return timesteps, num_inference_steps - t_start - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) - state.add_intermediate("latents", latents) + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 - return pipeline, state + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) -class DecodeLatentsStep(PipelineBlock): - optional_components = ["vae"] - required_auxiliaries = ["image_processor"] + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("output_type", "pil"), - ("return_dict", True), - ] + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents - @property - def intermediates_inputs(self) -> List[str]: - return ["latents"] + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) - @property - def intermediates_outputs(self) -> List[str]: - return ["images"] + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): - if image_processor is None: - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - output_type = state.get_input("output_type") - return_dict = state.get_input("return_dict") + image = image.to(device=device, dtype=dtype) - latents = state.get_intermediate("latents") + batch_size = batch_size * num_images_per_prompt - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast + if image.shape[1] == 4: + init_latents = image - if needs_upcasting: - pipeline.upcast_vae() - latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif latents.dtype != pipeline.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(latents.dtype) + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None - ) - has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None - ) - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = latents * latents_std / pipeline.vae.config.scaling_factor + latents_mean - else: - latents = latents / pipeline.vae.config.scaling_factor - image = pipeline.vae.decode(latents, return_dict=False)[0] + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) - # cast back to fp16 if needed - if needs_upcasting: - pipeline.vae.to(dtype=torch.float16) - else: - image = latents + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) - # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - image = pipeline.watermark.apply_watermark(image) + if self.vae.config.force_upcast: + self.vae.to(dtype) - image = pipeline.image_processor.postprocess(image, output_type=output_type) + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents - if not return_dict: - output = (image,) + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) else: - output = StableDiffusionXLPipelineOutput(images=image) + init_latents = torch.cat([init_latents], dim=0) - state.add_intermediate("images", image) - state.add_output("images", output) + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - return pipeline, state + latents = init_latents + return latents -class PipelineBlockType(Enum): - InputStep = 1 - TextEncoderStep = 2 - SetTimestepsStep = 3 - PrepareLatentsStep = 4 - PrepareAdditionalConditioningStep = 5 - PrepareGuidance = 6 - DenoiseStep = 7 - DecodeLatentsStep = 8 - - -PIPELINE_BLOCKS = { - StableDiffusionXLPipeline: [ - PipelineBlockType.InputStep, - PipelineBlockType.TextEncoderStep, - PipelineBlockType.SetTimestepsStep, - PipelineBlockType.PrepareLatentsStep, - PipelineBlockType.PrepareAdditionalConditioningStep, - PipelineBlockType.PrepareGuidance, - PipelineBlockType.DenoiseStep, - PipelineBlockType.DecodeLatentsStep, - ], -} - - -class CustomPipelineBuilder: - def __init__(self, pipeline_class: str): - if pipeline_class == "SDXL": - self.pipeline = SDXLCustomPipeline() - else: - raise ValueError(f"Pipeline class {pipeline_class} not supported") - self.pipeline_blocks = [] - self.pipeline.builder = self - - def add_blocks(self, pipeline_blocks: Union[PipelineBlock, List[PipelineBlock]]): - if not isinstance(pipeline_blocks, list): - pipeline_blocks = [pipeline_blocks] - - for block in pipeline_blocks: - self.pipeline_blocks.append(block) - # filter out components that already exist in the pipeline - components_to_register = {} - for k, v in block.components.items(): - if not hasattr(self.pipeline, k) or v is not None: - components_to_register[k] = v - self.pipeline.register_modules(**components_to_register) - self.pipeline.register_to_config(**block.configs) - # Add auxiliaries as attributes to the pipeline - for key, value in block.auxiliaries.items(): - setattr(self.pipeline, key, value) - - for required_component in block.required_components: - if ( - not hasattr(self.pipeline, required_component) - or getattr(self.pipeline, required_component) is None - ): - raise ValueError( - f"Cannot add block {block.__class__.__name__}: Required component {required_component} not found in pipeline" - ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] - for required_auxiliary in block.required_auxiliaries: - if ( - not hasattr(self.pipeline, required_auxiliary) - or getattr(self.pipeline, required_auxiliary) is None - ): - raise ValueError( - f"Cannot add block {block.__class__.__name__}: Required auxiliary {required_auxiliary} not found in pipeline" - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta - def run_blocks(self, state: PipelineState = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: """ - if state is None: - state = PipelineState() - - pipeline = self.pipeline - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # user can pass the intermediate of the first block - for name in self.pipeline_blocks[0].intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - for name, default in default_params.items(): - if name in input_params: - state.add_input(name, input_params.pop(name)) - elif name not in state.inputs: - state.add_input(name, default) - - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - for block in self.pipeline_blocks: - try: - pipeline, state = block(pipeline, state) - except Exception: - error_msg = f"Error in block: ({block.__class__.__name__}):\n" - logger.error(error_msg) - raise - - return state - - def run_pipeline(self, **kwargs): - state = PipelineState() - pipeline = self.pipeline - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # Add inputs to state, using defaults if not provided - for name, default in default_params.items(): - if name in input_params: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, default) + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. - # Run the pipeline - with torch.no_grad(): - for block in self.pipeline_blocks: - try: - pipeline, state = block(pipeline, state) - except Exception: - error_msg = f"Error in block: ({block.__class__.__name__}):\n" - logger.error(error_msg) - raise + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 - return state.get_output("images") + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for block in self.pipeline_blocks: - for name, default in block.inputs: - if name not in params: - params[name] = default - return params - - def __repr__(self): - output = "CustomPipeline Configuration:\n" - output += "==============================\n\n" - - # List the blocks used to build the pipeline - output += "Pipeline Blocks:\n" - output += "----------------\n" - for i, block in enumerate(self.pipeline_blocks, 1): - output += f"{i}. {block.__class__.__name__}\n" - - intermediates_str = "" - if hasattr(block, "intermediates_inputs"): - intermediates_str += f"{', '.join(block.intermediates_inputs)}" - - if hasattr(block, "intermediates_outputs"): - if intermediates_str: - intermediates_str += " -> " - else: - intermediates_str += "-> " - intermediates_str += f"{', '.join(block.intermediates_outputs)}" - - if intermediates_str: - output += f" {intermediates_str}\n" - - output += "\n" - output += "\n" - - # List the components registered in the pipeline - output += "Registered Components:\n" - output += "----------------------\n" - for name, component in self.pipeline.components.items(): - output += f"{name}: {type(component).__name__}\n" - output += "\n" - - # List the default call parameters - output += "Default Call Parameters:\n" - output += "------------------------\n" - params = self.default_call_parameters - for name, default in params.items(): - output += f"{name}: {default!r}\n" - - # Add a section for required call parameters: - # intermediate inputs for the first block - output += "\nRequired Call Parameters:\n" - output += "--------------------------\n" - for name in self.pipeline_blocks[0].intermediates_inputs: - output += f"{name}: \n" - params[name] = "" - - output += "\nNote: These are the default values. Actual values may be different when running the pipeline." - return output From c70a285c2cd8bb0623494398ab90b8f4797efe08 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Oct 2024 10:33:25 +0100 Subject: [PATCH 017/170] style --- src/diffusers/__init__.py | 4 +- src/diffusers/pipelines/auto_pipeline.py | 2 +- .../pipelines/modular_pipeline_builder.py | 110 +++++++----------- .../pipelines/stable_diffusion_xl/__init__.py | 4 +- .../pipeline_stable_diffusion_xl_modular.py | 16 +-- 5 files changed, 55 insertions(+), 81 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e2285f548c2f..b686d5e0edd8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -145,12 +145,12 @@ "DDIMPipeline", "DDPMPipeline", "DiffusionPipeline", - "ModularPipelineBuilder", "DiTPipeline", "ImagePipelineOutput", "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", + "ModularPipelineBuilder", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", @@ -366,11 +366,11 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLModularPipeline", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index ac52024c7412..194f7fb5ae36 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -220,8 +220,8 @@ def _get_model(pipeline_class_name): if pipeline.__name__ == pipeline_class_name: return model_name -def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): +def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): model_name = _get_model(pipeline_class_name) if model_name is not None: diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index 98c2a9139e44..a39471455a15 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -14,31 +14,22 @@ import inspect from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union -import importlib -from collections import OrderedDict -import PIL +from typing import Any, Dict, List, Tuple, Union + import torch from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin -from ..loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ..models import ImageProjection -from ..models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ..models.lora import adjust_lora_scale_text_encoder from ..utils import ( - USE_PEFT_BACKEND, is_accelerate_available, is_accelerate_version, logging, - scale_lora_layers, - unscale_lora_layers, ) from ..utils.hub_utils import validate_hf_hub_args -from ..utils.torch_utils import randn_tensor -from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class -from .pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .auto_pipeline import _get_model +from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class +from .pipeline_utils import DiffusionPipeline + if is_accelerate_available(): import accelerate @@ -51,7 +42,6 @@ } - @dataclass class PipelineState: """ @@ -225,6 +215,7 @@ class ModularPipelineBuilder(ConfigMixin): Base class for all Modular pipelines. """ + config_name = "model_index.json" model_cpu_offload_seq = None hf_device_map = None @@ -316,7 +307,7 @@ def components(self) -> Dict[str, Any]: expected_components = set() for block in self.pipeline_blocks: expected_components.update(block.components.keys()) - + components = {} for name in expected_components: if hasattr(self, name): @@ -349,8 +340,8 @@ def auxiliaries(self) -> Dict[str, Any]: @property def configs(self) -> Dict[str, Any]: r""" - The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the - pipeline blocks. + The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the pipeline + blocks. Returns (`dict`): A dictionary containing all the configs defined in the pipeline blocks. @@ -393,31 +384,32 @@ def __call__(self, *args, **kwargs): def remove_blocks(self, indices: Union[int, List[int]]): """ - Remove one or more blocks from the pipeline by their indices and clean up associated components, - configs, and auxiliaries that are no longer needed by remaining blocks. + Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and + auxiliaries that are no longer needed by remaining blocks. Args: indices (Union[int, List[int]]): The index or list of indices of blocks to remove """ # Convert single index to list indices = [indices] if isinstance(indices, int) else indices - + # Validate indices for idx in indices: if not 0 <= idx < len(self.pipeline_blocks): - raise ValueError(f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}") - + raise ValueError( + f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}" + ) + # Sort indices in descending order to avoid shifting issues when removing indices = sorted(indices, reverse=True) - + # Store blocks to be removed blocks_to_remove = [self.pipeline_blocks[idx] for idx in indices] - + # Remove blocks from pipeline for idx in indices: self.pipeline_blocks.pop(idx) - # Consolidate items to remove from all blocks components_to_remove = {k: v for block in blocks_to_remove for k, v in block.components.items()} auxiliaries_to_remove = {k: v for block in blocks_to_remove for k, v in block.auxiliaries.items()} @@ -448,7 +440,7 @@ def remove_blocks(self, indices: Union[int, List[int]]): def add_blocks(self, pipeline_blocks, at: int = -1): """Add blocks to the pipeline. - + Args: pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances. at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end). @@ -456,7 +448,7 @@ def add_blocks(self, pipeline_blocks, at: int = -1): # Convert single block to list for uniform processing if not isinstance(pipeline_blocks, (list, tuple)): pipeline_blocks = [pipeline_blocks] - + # Validate insert_at index if at != -1 and not 0 <= at <= len(self.pipeline_blocks): raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks)}") @@ -465,7 +457,7 @@ def add_blocks(self, pipeline_blocks, at: int = -1): components_to_add = {} configs_to_add = {} auxiliaries_to_add = {} - + # Add blocks in order for i, block in enumerate(pipeline_blocks): # Add block to pipeline at specified position @@ -473,16 +465,16 @@ def add_blocks(self, pipeline_blocks, at: int = -1): self.pipeline_blocks.append(block) else: self.pipeline_blocks.insert(at + i, block) - + # Collect components that don't already exist for k, v in block.components.items(): if not hasattr(self, k) or (getattr(self, k, None) is None and v is not None): components_to_add[k] = v - + # Collect configs and auxiliaries configs_to_add.update(block.configs) auxiliaries_to_add.update(block.auxiliaries) - + # Validate all required components and auxiliaries after consolidation for block in pipeline_blocks: for required_component in block.required_components: @@ -513,44 +505,37 @@ def add_blocks(self, pipeline_blocks, at: int = -1): if configs_to_add: self.register_to_config(**configs_to_add) for key, value in auxiliaries_to_add.items(): - setattr(self, key, value) def replace_blocks(self, pipeline_blocks, at: int): """Replace one or more blocks in the pipeline at the specified index. - + Args: - pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances + pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances that will replace existing blocks. at (int): Index at which to replace the blocks. """ # Convert single block to list for uniform processing if not isinstance(pipeline_blocks, (list, tuple)): pipeline_blocks = [pipeline_blocks] - + # Validate replace_at index if not 0 <= at < len(self.pipeline_blocks): - raise ValueError( - f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}" - ) - + raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}") + # Add new blocks first self.add_blocks(pipeline_blocks, at=at) - + # Calculate indices to remove # We need to remove the original blocks that are now shifted by the length of pipeline_blocks - indices_to_remove = list(range( - at + len(pipeline_blocks), - at + len(pipeline_blocks) * 2 - )) - + indices_to_remove = list(range(at + len(pipeline_blocks), at + len(pipeline_blocks) * 2)) + # Remove the old blocks self.remove_blocks(indices_to_remove) @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_or_path, **kwargs): - # (1) create the base pipeline cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -579,11 +564,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) - # (3) create the pipeline blocks pipeline_blocks = [ - block_class.from_pipe(base_pipeline) - for block_class in modular_pipeline_class.default_pipeline_blocks + block_class.from_pipe(base_pipeline) for block_class in modular_pipeline_class.default_pipeline_blocks ] # (4) create the builder @@ -591,35 +574,31 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): builder.add_blocks(pipeline_blocks) return builder - + @classmethod def from_pipe(cls, pipeline, **kwargs): base_pipeline_class_name = pipeline.__class__.__name__ modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) - + pipeline_blocks = [] # Create each block, passing only unused items that the block expects for block_class in modular_pipeline_class.default_pipeline_blocks: expected_components = set(block_class.required_components + block_class.optional_components) expected_auxiliaries = set(block_class.required_auxiliaries) - + # Get init parameters to check for expected configs init_params = inspect.signature(block_class.__init__).parameters expected_configs = { - k for k in init_params - if k not in expected_components - and k not in expected_auxiliaries + k for k in init_params if k not in expected_components and k not in expected_auxiliaries } - + block_kwargs = {} - + for key, value in kwargs.items(): - if (key in expected_components or - key in expected_auxiliaries or - key in expected_configs): + if key in expected_components or key in expected_auxiliaries or key in expected_configs: block_kwargs[key] = value - + # Create the block with filtered kwargs block = block_class.from_pipe(pipeline, **block_kwargs) pipeline_blocks.append(block) @@ -630,10 +609,10 @@ def from_pipe(cls, pipeline, **kwargs): # Warn about unused kwargs unused_kwargs = { - k: v for k, v in kwargs.items() + k: v + for k, v in kwargs.items() if not any( - k in block.components or k in block.auxiliaries or k in block.configs - for block in pipeline_blocks + k in block.components or k in block.auxiliaries or k in block.configs for block in pipeline_blocks ) } if unused_kwargs: @@ -774,7 +753,6 @@ def __repr__(self): output += f"{name}: {config!r}\n" output += "\n" - # List the default call parameters output += "Default Call Parameters:\n" output += "------------------------\n" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index ab5c6bde7d54..a1b821d1726f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -30,6 +30,7 @@ _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] _import_structure["pipeline_stable_diffusion_xl_modular"] = [ + "StableDiffusionXLControlNetDenoiseStep", "StableDiffusionXLDecodeLatentsStep", "StableDiffusionXLDenoiseStep", "StableDiffusionXLInputStep", @@ -38,7 +39,6 @@ "StableDiffusionXLPrepareLatentsStep", "StableDiffusionXLSetTimestepsStep", "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLControlNetDenoiseStep", ] if is_transformers_available() and is_flax_available(): @@ -60,6 +60,7 @@ from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline from .pipeline_stable_diffusion_xl_modular import ( + StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDenoiseStep, StableDiffusionXLInputStep, @@ -68,7 +69,6 @@ StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLTextEncoderStep, - StableDiffusionXLControlNetDenoiseStep, ) try: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 061b6a8dd0f5..67b7d2c78d9c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -43,7 +43,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -118,8 +117,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") - - class StableDiffusionXLInputStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: @@ -837,7 +834,9 @@ def __init__( if controlnet_guider is None: controlnet_guider = CFGGuider() if control_image_processor is None: - control_image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False) + control_image_processor = VaeImageProcessor( + vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) super().__init__( unet=unet, controlnet=controlnet, @@ -1155,24 +1154,22 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state - class StableDiffusionXLModularPipeline( ModularPipelineBuilder, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, ): - default_pipeline_blocks = [ StableDiffusionXLInputStep, - StableDiffusionXLTextEncoderStep, + StableDiffusionXLTextEncoderStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLDenoiseStep, - StableDiffusionXLDecodeLatentsStep + StableDiffusionXLDecodeLatentsStep, ] - + def __init__(self): super().__init__() @@ -1817,4 +1814,3 @@ def get_guidance_scale_embedding( emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb - From ffc2992fc2c6a6e3c600c41f0b605d26a242e240 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 16 Nov 2024 22:42:06 +0100 Subject: [PATCH 018/170] add autostep (not complete) --- .../pipelines/modular_pipeline_builder.py | 199 +++++++++++++++++- .../pipeline_stable_diffusion_xl_modular.py | 22 +- 2 files changed, 209 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index a39471455a15..c9f1ca6f082e 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -14,7 +14,8 @@ import inspect from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Type +from collections import OrderedDict import torch from tqdm.auto import tqdm @@ -30,6 +31,8 @@ from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class from .pipeline_utils import DiffusionPipeline +import warnings + if is_accelerate_available(): import accelerate @@ -99,6 +102,7 @@ class PipelineBlock: optional_components = [] required_components = [] required_auxiliaries = [] + optional_auxiliaries = [] @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -122,7 +126,7 @@ def __init__(self, **kwargs): for key, value in kwargs.items(): if key in self.required_components or key in self.optional_components: self.components[key] = value - elif key in self.required_auxiliaries: + elif key in self.required_auxiliaries or key in self.optional_auxiliaries: self.auxiliaries[key] = value else: self.configs[key] = value @@ -152,10 +156,11 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): components_to_add[component_name] = component # add auxiliaries + expected_auxiliaries = set(cls.required_auxiliaries + cls.optional_auxiliaries) # - auxiliaries that are passed in kwargs - auxiliaries_to_add = {k: kwargs.pop(k) for k in cls.required_auxiliaries if k in kwargs} + auxiliaries_to_add = {k: kwargs.pop(k) for k in expected_auxiliaries if k in kwargs} # - auxiliaries that are in the pipeline - for aux_name in cls.required_auxiliaries: + for aux_name in expected_auxiliaries: if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add: auxiliaries_to_add[aux_name] = getattr(pipe, aux_name) block_kwargs = {**components_to_add, **auxiliaries_to_add} @@ -167,7 +172,7 @@ def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): expected_configs = { k for k in pipe.config.keys() - if k in init_params and k not in expected_components and k not in cls.required_auxiliaries + if k in init_params and k not in expected_components and k not in expected_auxiliaries } for config_name in expected_configs: @@ -210,6 +215,188 @@ def __repr__(self): ) +def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: + """ + Combines multiple lists of (name, default_value) tuples. + For duplicate inputs, updates only if current value is None and new value is not None. + Warns if multiple non-None default values exist for the same input. + """ + combined_dict = {} + for inputs in input_lists: + for name, value in inputs: + if name in combined_dict: + current_value = combined_dict[name] + if current_value is not None and value is not None and current_value != value: + warnings.warn( + f"Multiple different default values found for input '{name}': " + f"{current_value} and {value}. Using {current_value}." + ) + if current_value is None and value is not None: + combined_dict[name] = value + else: + combined_dict[name] = value + return list(combined_dict.items()) + + + +class AutoStep(PipelineBlock): + base_blocks = [] # list of block classes + trigger_inputs = [] # list of trigger inputs (None for default block) + required_components = [] + optional_components = [] + required_auxiliaries = [] + optional_auxiliaries = [] + + def __init__(self, **kwargs): + self.blocks = [] + + for block_cls, trigger in zip(self.base_blocks, self.trigger_inputs): + # Check components + missing_components = [ + component for component in block_cls.required_components + if component not in kwargs + ] + + # Check auxiliaries + missing_auxiliaries = [ + auxiliary for auxiliary in block_cls.required_auxiliaries + if auxiliary not in kwargs + ] + + if not missing_components and not missing_auxiliaries: + # Only get kwargs that the block's __init__ accepts + block_params = inspect.signature(block_cls.__init__).parameters + block_kwargs = { + k: v for k, v in kwargs.items() + if k in block_params + } + self.blocks.append(block_cls(**block_kwargs)) + + # Print message about trigger condition + if trigger is None: + print(f"Added default block: {block_cls.__name__}") + else: + print(f"Added block {block_cls.__name__} - will be dispatched if '{trigger}' input is not None") + else: + if trigger is None: + print(f"Cannot add default block {block_cls.__name__}:") + else: + print(f"Cannot add block {block_cls.__name__} (triggered by '{trigger}'):") + if missing_components: + print(f" - Missing components: {missing_components}") + if missing_auxiliaries: + print(f" - Missing auxiliaries: {missing_auxiliaries}") + + @property + def components(self): + # Combine components from all blocks + components = {} + for block in self.blocks: + components.update(block.components) + return components + + @property + def auxiliaries(self): + # Combine auxiliaries from all blocks + auxiliaries = {} + for block in self.blocks: + auxiliaries.update(block.auxiliaries) + return auxiliaries + + @property + def configs(self): + # Combine configs from all blocks + configs = {} + for block in self.blocks: + configs.update(block.configs) + return configs + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return combine_inputs(*(block.inputs for block in self.blocks)) + + @property + def intermediates_inputs(self) -> List[str]: + return list(set().union(*( + block.intermediates_inputs for block in self.blocks + ))) + + @property + def intermediates_outputs(self) -> List[str]: + return list(set().union(*( + block.intermediates_outputs for block in self.blocks + ))) + + def __call__(self, pipeline, state): + # Check triggers in priority order + for idx, trigger in enumerate(self.trigger_inputs[:-1]): # Skip last (None) trigger + if state.get_input(trigger) is not None: + return self.blocks[idx](pipeline, state) + # If no triggers match, use the default block (last one) + return self.blocks[-1](pipeline, state) + + +def make_auto_step(pipeline_block_map: OrderedDict) -> Type[AutoStep]: + """ + Creates a new AutoStep subclass with updated class attributes based on the pipeline block map. + + Args: + pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes. + Order determines priority (earlier entries take precedence). + Must include None key for the default block. + """ + blocks = list(pipeline_block_map.values()) + triggers = list(pipeline_block_map.keys()) + + # Get all expected components (either required or optional by any block) + expected_components = [] + for block in blocks: + for component in (block.required_components + block.optional_components): + if component not in expected_components: + expected_components.append(component) + + # A component is required if it's in required_components of all blocks + required_components = [ + component for component in expected_components + if all(component in block.required_components for block in blocks) + ] + + # All other expected components are optional + optional_components = [ + component for component in expected_components + if component not in required_components + ] + + # Get all expected auxiliaries (either required or optional by any block) + expected_auxiliaries = [] + for block in blocks: + for auxiliary in (block.required_auxiliaries + getattr(block, 'optional_auxiliaries', [])): + if auxiliary not in expected_auxiliaries: + expected_auxiliaries.append(auxiliary) + + # An auxiliary is required if it's in required_auxiliaries of all blocks + required_auxiliaries = [ + auxiliary for auxiliary in expected_auxiliaries + if all(auxiliary in block.required_auxiliaries for block in blocks) + ] + + # All other expected auxiliaries are optional + optional_auxiliaries = [ + auxiliary for auxiliary in expected_auxiliaries + if auxiliary not in required_auxiliaries + ] + + # Create new class with updated attributes + return type('AutoStep', (AutoStep,), { + 'base_blocks': blocks, + 'trigger_inputs': triggers, + 'required_components': required_components, + 'optional_components': optional_components, + 'required_auxiliaries': required_auxiliaries, + 'optional_auxiliaries': optional_auxiliaries, + }) + + class ModularPipelineBuilder(ConfigMixin): """ Base class for all Modular pipelines. @@ -585,7 +772,7 @@ def from_pipe(cls, pipeline, **kwargs): # Create each block, passing only unused items that the block expects for block_class in modular_pipeline_class.default_pipeline_blocks: expected_components = set(block_class.required_components + block_class.optional_components) - expected_auxiliaries = set(block_class.required_auxiliaries) + expected_auxiliaries = set(block_class.required_auxiliaries + block_class.optional_auxiliaries) # Get init parameters to check for expected configs init_params = inspect.signature(block_class.__init__).parameters diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 67b7d2c78d9c..46d0491b11cc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -14,6 +14,7 @@ import inspect from typing import Any, List, Optional, Tuple, Union +from collections import OrderedDict import PIL import torch @@ -33,7 +34,7 @@ ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState +from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState, make_auto_step from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import ( StableDiffusionXLPipelineOutput, @@ -401,7 +402,7 @@ def denoising_value_valid(dnv): class StableDiffusionXLPrepareLatentsStep(PipelineBlock): optional_components = ["vae", "scheduler"] - required_auxiliaries = ["image_processor"] + optional_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -645,7 +646,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLDenoiseStep(PipelineBlock): required_components = ["unet", "scheduler"] - required_auxiliaries = ["guider"] + optional_auxiliaries = ["guider"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -780,7 +781,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): required_components = ["unet", "controlnet", "scheduler"] - required_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] + optional_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1069,7 +1070,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): optional_components = ["vae"] - required_auxiliaries = ["image_processor"] + optional_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1154,6 +1155,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +AUTO_DENOISE_BLOCK_MAP = OrderedDict([ + # Higher priority blocks first + ("control_image", StableDiffusionXLControlNetDenoiseStep), + # Default block + (None, StableDiffusionXLDenoiseStep), +]) + +StableDiffusionXLAutoDenoiseStep = make_auto_step(AUTO_DENOISE_BLOCK_MAP) + class StableDiffusionXLModularPipeline( ModularPipelineBuilder, StableDiffusionMixin, @@ -1166,7 +1176,7 @@ class StableDiffusionXLModularPipeline( StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLDenoiseStep, + StableDiffusionXLAutoDenoiseStep, StableDiffusionXLDecodeLatentsStep, ] From ace53e2d2fa3f487714f333660bb21d95f2e058b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 10 Dec 2024 03:41:28 +0100 Subject: [PATCH 019/170] update/refactor --- src/diffusers/guider.py | 8 +- .../pipelines/modular_pipeline_builder.py | 916 ++++++++++++------ .../pipeline_stable_diffusion_xl_modular.py | 490 +++++++--- src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + 5 files changed, 970 insertions(+), 474 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index 5369bd25f9e2..96dced267baa 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -351,9 +351,13 @@ def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): ) def reset_guider(self, pipeline): - if self.do_perturbed_attention_guidance: + if ( + self.do_perturbed_attention_guidance + and hasattr(self, "original_attn_proc") + and self.original_attn_proc is not None + ): pipeline.unet.set_attn_processor(self.original_attn_proc) - pipeline.original_attn_proc = None + self.original_attn_proc = None def maybe_update_guider(self, pipeline, timestep): pass diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index c9f1ca6f082e..590165e170c5 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Type +import traceback +import warnings from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union import torch from tqdm.auto import tqdm @@ -27,12 +28,9 @@ logging, ) from ..utils.hub_utils import validate_hf_hub_args -from .auto_pipeline import _get_model -from .pipeline_loading_utils import _fetch_class_library_tuple, _get_pipeline_class +from .pipeline_loading_utils import _fetch_class_library_tuple from .pipeline_utils import DiffusionPipeline -import warnings - if is_accelerate_available(): import accelerate @@ -99,10 +97,9 @@ def format_value(v): class PipelineBlock: - optional_components = [] - required_components = [] - required_auxiliaries = [] - optional_auxiliaries = [] + expected_components = [] + expected_auxiliaries = [] + expected_configs = [] @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -117,88 +114,117 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [] + def update_states(self, **kwargs): + """ + Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for + each pipeline block, does not need to be updated by users. Logs if existing non-None states are being + overwritten. + + Args: + **kwargs: Keyword arguments containing components, or configs to add/update. + e.g. pipeline_block.update_states(unet=unet1, vae=None) + """ + # Add expected components + for component_name in self.expected_components: + if component_name in kwargs: + if component_name in self.components and self.components[component_name] is not None: + if id(self.components[component_name]) != id(kwargs[component_name]): + logger.info( + f"Overwriting existing component '{component_name}' " + f"(type: {type(self.components[component_name]).__name__}) " + f"with new value (type: {type(kwargs[component_name]).__name__})" + ) + self.components[component_name] = kwargs.pop(component_name) + + # Add expected configs + for config_name in self.expected_configs: + if config_name in kwargs: + if config_name in self.configs and self.configs[config_name] is not None: + if self.configs[config_name] != kwargs[config_name]: + logger.info( + f"Overwriting existing config '{config_name}' " + f"(value: {self.configs[config_name]}) " + f"with new value ({kwargs[config_name]})" + ) + self.configs[config_name] = kwargs.pop(config_name) + def __init__(self, **kwargs): self.components: Dict[str, Any] = {} self.auxiliaries: Dict[str, Any] = {} self.configs: Dict[str, Any] = {} - # Process kwargs - for key, value in kwargs.items(): - if key in self.required_components or key in self.optional_components: - self.components[key] = value - elif key in self.required_auxiliaries or key in self.optional_auxiliaries: - self.auxiliaries[key] = value - else: - self.configs[key] = value + self.update_states(**kwargs) - @classmethod - def from_pipe(cls, pipe: DiffusionPipeline, **kwargs): + # YiYi notes, does pipeline block need "states"? it is not going to be used on its own + # TODO: address existing components -> overwrite or not? currently overwrite + def add_states_from_pipe(self, pipe: DiffusionPipeline, **kwargs): """ - Create a PipelineBlock instance from a diffusion pipeline object. + add components/auxiliaries/configs from a diffusion pipeline object. Args: pipe: A `[DiffusionPipeline]` object. + **kwargs: Additional states to update, these take precedence over pipe values. Returns: - PipelineBlock: An instance initialized with the pipeline's components and configurations. + PipelineBlock: An instance loaded with the pipeline's components and configurations. """ - # add components - expected_components = set(cls.required_components + cls.optional_components) - # - components that are passed in kwargs - components_to_add = { - component_name: kwargs.pop(component_name) - for component_name in expected_components - if component_name in kwargs - } - # - components that are in the pipeline - for component_name, component in pipe.components.items(): - if component_name in expected_components and component_name not in components_to_add: - components_to_add[component_name] = component - - # add auxiliaries - expected_auxiliaries = set(cls.required_auxiliaries + cls.optional_auxiliaries) - # - auxiliaries that are passed in kwargs - auxiliaries_to_add = {k: kwargs.pop(k) for k in expected_auxiliaries if k in kwargs} - # - auxiliaries that are in the pipeline - for aux_name in expected_auxiliaries: - if hasattr(pipe, aux_name) and aux_name not in auxiliaries_to_add: - auxiliaries_to_add[aux_name] = getattr(pipe, aux_name) - block_kwargs = {**components_to_add, **auxiliaries_to_add} - - # add pipeline configs - init_params = inspect.signature(cls.__init__).parameters - # modules info are also registered in the config as tuples, e.g. {'tokenizer': ('transformers', 'CLIPTokenizer')} - # we need to exclude them for block_kwargs otherwise it will override the actual module - expected_configs = { - k - for k in pipe.config.keys() - if k in init_params and k not in expected_components and k not in expected_auxiliaries - } - - for config_name in expected_configs: - if config_name not in block_kwargs: - if config_name in kwargs: - # - configs that are passed in kwargs - block_kwargs[config_name] = kwargs.pop(config_name) - else: - # - configs that are in the pipeline - block_kwargs[config_name] = pipe.config[config_name] - - # Add any remaining relevant pipeline attributes - for attr_name in dir(pipe): - if attr_name not in block_kwargs and attr_name in init_params: - block_kwargs[attr_name] = getattr(pipe, attr_name) + states_to_update = {} + + # Get components - prefer kwargs over pipe values + for component_name in self.expected_components: + if component_name in kwargs: + states_to_update[component_name] = kwargs.pop(component_name) + elif component_name in pipe.components: + states_to_update[component_name] = pipe.components[component_name] + + # Get configs - prefer kwargs over pipe values + pipe_config = dict(pipe.config) + for config_name in self.expected_configs: + if config_name in kwargs: + states_to_update[config_name] = kwargs.pop(config_name) + elif config_name in pipe_config: + states_to_update[config_name] = pipe_config[config_name] + + # Update all states at once + self.update_states(**states_to_update) - return cls(**block_kwargs) + @validate_hf_hub_args + def add_states_from_pretrained(self, pretrained_model_or_path, **kwargs): + base_pipeline = DiffusionPipeline.from_pretrained(pretrained_model_or_path, **kwargs) + self.add_states_from_pipe(base_pipeline, **kwargs) def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") def __repr__(self): class_name = self.__class__.__name__ - components = ", ".join(f"{k}={type(v).__name__}" for k, v in self.components.items()) - auxiliaries = ", ".join(f"{k}={type(v).__name__}" for k, v in self.auxiliaries.items()) - configs = ", ".join(f"{k}={v}" for k, v in self.configs.items()) + + # Components section + expected_components = set(getattr(self, "expected_components", [])) + loaded_components = set(self.components.keys()) + all_components = sorted(expected_components | loaded_components) + components = ", ".join( + f"{k}={type(self.components[k]).__name__}" if k in loaded_components else f"{k}" for k in all_components + ) + + # Auxiliaries section + expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) + loaded_auxiliaries = set(self.auxiliaries.keys()) + all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) + auxiliaries = ", ".join( + f"{k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f"{k}" for k in all_auxiliaries + ) + + # Configs section + expected_configs = set(getattr(self, "expected_configs", [])) + loaded_configs = set(self.configs.keys()) + all_configs = sorted(expected_configs | loaded_configs) + configs = ", ".join(f"{k}={self.configs[k]}" if k in loaded_configs else f"{k}" for k in all_configs) + + # Single block shows itself + blocks = f"step={self.__class__.__name__}" + + # Other information inputs = ", ".join(f"{name}={default}" for name, default in self.inputs) intermediates_inputs = ", ".join(self.intermediates_inputs) intermediates_outputs = ", ".join(self.intermediates_outputs) @@ -208,6 +234,7 @@ def __repr__(self): f" components: {components}\n" f" auxiliaries: {auxiliaries}\n" f" configs: {configs}\n" + f" blocks: {blocks}\n" f" inputs: {inputs}\n" f" intermediates_inputs: {intermediates_inputs}\n" f" intermediates_outputs: {intermediates_outputs}\n" @@ -217,9 +244,8 @@ def __repr__(self): def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: """ - Combines multiple lists of (name, default_value) tuples. - For duplicate inputs, updates only if current value is None and new value is not None. - Warns if multiple non-None default values exist for the same input. + Combines multiple lists of (name, default_value) tuples. For duplicate inputs, updates only if current value is + None and new value is not None. Warns if multiple non-None default values exist for the same input. """ combined_dict = {} for inputs in input_lists: @@ -238,163 +264,513 @@ def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]] return list(combined_dict.items()) +class MultiPipelineBlocks: + """ + A class that combines multiple pipeline block classes into one. When used, it has same API and properties as + PipelineBlock. And it can be used in ModularPipelineBuilder as a single pipeline block. + """ + + block_classes = [] + block_names = [] + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_auxiliaries(self): + expected_auxiliaries = [] + for block in self.blocks.values(): + for auxiliary in block.expected_auxiliaries: + if auxiliary not in expected_auxiliaries: + expected_auxiliaries.append(auxiliary) + return expected_auxiliaries + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs -class AutoStep(PipelineBlock): - base_blocks = [] # list of block classes - trigger_inputs = [] # list of trigger inputs (None for default block) - required_components = [] - optional_components = [] - required_auxiliaries = [] - optional_auxiliaries = [] - def __init__(self, **kwargs): - self.blocks = [] - - for block_cls, trigger in zip(self.base_blocks, self.trigger_inputs): - # Check components - missing_components = [ - component for component in block_cls.required_components - if component not in kwargs - ] - - # Check auxiliaries - missing_auxiliaries = [ - auxiliary for auxiliary in block_cls.required_auxiliaries - if auxiliary not in kwargs - ] - - if not missing_components and not missing_auxiliaries: - # Only get kwargs that the block's __init__ accepts - block_params = inspect.signature(block_cls.__init__).parameters - block_kwargs = { - k: v for k, v in kwargs.items() - if k in block_params - } - self.blocks.append(block_cls(**block_kwargs)) - - # Print message about trigger condition - if trigger is None: - print(f"Added default block: {block_cls.__name__}") - else: - print(f"Added block {block_cls.__name__} - will be dispatched if '{trigger}' input is not None") - else: - if trigger is None: - print(f"Cannot add default block {block_cls.__name__}:") - else: - print(f"Cannot add block {block_cls.__name__} (triggered by '{trigger}'):") - if missing_components: - print(f" - Missing components: {missing_components}") - if missing_auxiliaries: - print(f" - Missing auxiliaries: {missing_auxiliaries}") - + blocks = OrderedDict() + for block_prefix, block_cls in zip(self.block_prefixes, self.block_classes): + block_name = f"{block_prefix}_step" if block_prefix != "" else "step" + blocks[block_name] = block_cls(**kwargs) + self.blocks = blocks + + # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc @property def components(self): # Combine components from all blocks components = {} - for block in self.blocks: + for block_name, block in self.blocks.items(): components.update(block.components) return components - + @property def auxiliaries(self): # Combine auxiliaries from all blocks auxiliaries = {} - for block in self.blocks: + for block_name, block in self.blocks.items(): auxiliaries.update(block.auxiliaries) return auxiliaries - + @property def configs(self): # Combine configs from all blocks configs = {} - for block in self.blocks: + for block_name, block in self.blocks.items(): configs.update(block.configs) return configs - + @property def inputs(self) -> List[Tuple[str, Any]]: - return combine_inputs(*(block.inputs for block in self.blocks)) - + raise NotImplementedError("inputs property must be implemented in subclasses") + @property def intermediates_inputs(self) -> List[str]: - return list(set().union(*( - block.intermediates_inputs for block in self.blocks - ))) - + raise NotImplementedError("intermediates_inputs property must be implemented in subclasses") + @property def intermediates_outputs(self) -> List[str]: - return list(set().union(*( - block.intermediates_outputs for block in self.blocks - ))) - + raise NotImplementedError("intermediates_outputs property must be implemented in subclasses") + def __call__(self, pipeline, state): - # Check triggers in priority order - for idx, trigger in enumerate(self.trigger_inputs[:-1]): # Skip last (None) trigger - if state.get_input(trigger) is not None: - return self.blocks[idx](pipeline, state) - # If no triggers match, use the default block (last one) - return self.blocks[-1](pipeline, state) + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def update_states(self, **kwargs): + """ + Update states for each block with support for block-specific kwargs. + + Args: + **kwargs: Can include both general kwargs (e.g., 'unet') and + block-specific kwargs (e.g., 'img2img_step_unet') + + Example: + pipeline.update_states( + img2img_step_unet=unet2, # Only for img2img_step step_unet=unet1, # Only for step vae=vae1 # For any + block that expects vae + ) + """ + for block_name, block in self.blocks.items(): + # Prepare block-specific kwargs + if isinstance(block, PipelineBlock): + block_kwargs = {} + + # Check for block-specific kwargs first (e.g., 'img2img_unet') + prefix = f"{block_name.replace('_step', '')}_" + for key, value in kwargs.items(): + if key.startswith(prefix): + # Remove prefix and add to block kwargs + block_kwargs[key[len(prefix) :]] = value + + # For any expected component/auxiliary/config not found with prefix, + # fall back to general kwargs + for name in ( + block.expected_components + + + # block.expected_auxiliaries + + block.expected_configs + ): + if name not in block_kwargs: + if name in kwargs: + block_kwargs[name] = kwargs[name] + elif isinstance(block, MultiPipelineBlocks): + block_kwargs = kwargs + else: + raise ValueError(f"Unsupported block type: {type(block).__name__}") + # Update the block with its specific kwargs + block.update_states(**block_kwargs) -def make_auto_step(pipeline_block_map: OrderedDict) -> Type[AutoStep]: + def add_states_from_pipe(self, pipe: DiffusionPipeline, **kwargs): + """ + Load components from pipe with support for block-specific kwargs. + + Args: + pipe: DiffusionPipeline object + **kwargs: Can include both general kwargs (e.g., 'unet') and + block-specific kwargs (e.g., 'img2img_unet' for 'img2img_step') + """ + for block_name, block in self.blocks.items(): + # Handle different block types + if isinstance(block, PipelineBlock): + block_kwargs = {} + + # Check for block-specific kwargs first (e.g., 'img2img_unet') + prefix = f"{block_name.replace('_step', '')}_" + for key, value in kwargs.items(): + if key.startswith(prefix): + # Remove prefix and add to block kwargs + block_kwargs[key[len(prefix) :]] = value + + # For any expected component/auxiliary/config not found with prefix, + # fall back to general kwargs + for name in ( + block.expected_components + + + # block.expected_auxiliaries + + block.expected_configs + ): + if name not in block_kwargs: + if name in kwargs: + block_kwargs[name] = kwargs[name] + elif isinstance(block, MultiPipelineBlocks): + block_kwargs = kwargs + else: + raise ValueError(f"Unsupported block type: {type(block).__name__}") + + # Load the block with its specific kwargs + block.add_states_from_pipe(pipe, **block_kwargs) + + def add_states_from_pretrained(self, pretrained_model_or_path, **kwargs): + base_pipeline = DiffusionPipeline.from_pretrained(pretrained_model_or_path, **kwargs) + self.add_states_from_pipe(base_pipeline, **kwargs) + + def __repr__(self): + class_name = self.__class__.__name__ + + # Components section + expected_components = set(getattr(self, "expected_components", [])) + loaded_components = set(self.components.keys()) + all_components = sorted(expected_components | loaded_components) + components_str = " Components:\n" + "\n".join( + f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + for k in all_components + ) + + # Auxiliaries section + expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) + loaded_auxiliaries = set(self.auxiliaries.keys()) + all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) + auxiliaries_str = " Auxiliaries:\n" + "\n".join( + f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" + for k in all_auxiliaries + ) + + # Configs section + expected_configs = set(getattr(self, "expected_configs", [])) + loaded_configs = set(self.configs.keys()) + all_configs = sorted(expected_configs | loaded_configs) + configs_str = " Configs:\n" + "\n".join( + f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs + ) + + # Blocks section + blocks_str = " Blocks:\n" + "\n".join( + f" - {name}={block.__class__.__name__}" for name, block in self.blocks.items() + ) + + # Other information + inputs_str = " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) + + intermediates_str = ( + " Intermediates:\n" + f" - inputs: {', '.join(self.intermediates_inputs)}\n" + f" - outputs: {', '.join(self.intermediates_outputs)}" + ) + + return ( + f"{class_name}(\n" + f"{components_str}\n" + f"{auxiliaries_str}\n" + f"{configs_str}\n" + f"{blocks_str}\n" + f"{inputs_str}\n" + f"{intermediates_str}\n" + f")" + ) + + +class AutoPipelineBlocks(MultiPipelineBlocks): """ - Creates a new AutoStep subclass with updated class attributes based on the pipeline block map. - - Args: - pipeline_block_map: OrderedDict mapping trigger inputs to pipeline block classes. - Order determines priority (earlier entries take precedence). - Must include None key for the default block. + A class that automatically selects which block to run based on trigger inputs. + + Attributes: + block_classes: List of block classes to be used + block_prefixes: List of prefixes for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for default """ - blocks = list(pipeline_block_map.values()) - triggers = list(pipeline_block_map.keys()) - - # Get all expected components (either required or optional by any block) - expected_components = [] - for block in blocks: - for component in (block.required_components + block.optional_components): - if component not in expected_components: - expected_components.append(component) - - # A component is required if it's in required_components of all blocks - required_components = [ - component for component in expected_components - if all(component in block.required_components for block in blocks) - ] - - # All other expected components are optional - optional_components = [ - component for component in expected_components - if component not in required_components - ] - - # Get all expected auxiliaries (either required or optional by any block) - expected_auxiliaries = [] - for block in blocks: - for auxiliary in (block.required_auxiliaries + getattr(block, 'optional_auxiliaries', [])): - if auxiliary not in expected_auxiliaries: - expected_auxiliaries.append(auxiliary) - - # An auxiliary is required if it's in required_auxiliaries of all blocks - required_auxiliaries = [ - auxiliary for auxiliary in expected_auxiliaries - if all(auxiliary in block.required_auxiliaries for block in blocks) - ] - - # All other expected auxiliaries are optional - optional_auxiliaries = [ - auxiliary for auxiliary in expected_auxiliaries - if auxiliary not in required_auxiliaries - ] - - # Create new class with updated attributes - return type('AutoStep', (AutoStep,), { - 'base_blocks': blocks, - 'trigger_inputs': triggers, - 'required_components': required_components, - 'optional_components': optional_components, - 'required_auxiliaries': required_auxiliaries, - 'optional_auxiliaries': optional_auxiliaries, - }) + + block_classes = [] + block_prefixes = [] + block_trigger_inputs = [] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.__post_init__() + + def __post_init__(self): + """ + Create mapping of trigger inputs directly to block objects. Validates that there is at most one default block + (None trigger). + """ + # Check for at most one default block + default_blocks = [t for t in self.block_trigger_inputs if t is None] + if len(default_blocks) > 1: + raise ValueError( + f"Multiple default blocks specified in {self.__class__.__name__}. " + "Must include at most one None in block_trigger_inputs." + ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return combine_inputs(*(block.inputs for block in self.blocks.values())) + + @property + def intermediates_inputs(self) -> List[str]: + return list(set().union(*(block.intermediates_inputs for block in self.blocks.values()))) + + @property + def intermediates_outputs(self) -> List[str]: + return list(set().union(*(block.intermediates_outputs for block in self.blocks.values()))) + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find default block first (if any) + default_block = self.trigger_to_block_map.get(None) + + # Check which trigger inputs are present + active_triggers = [ + input_name + for input_name in self.block_trigger_inputs + if input_name is not None and state.get_input(input_name) is not None + ] + + # If multiple triggers are active, raise error + if len(active_triggers) > 1: + trigger_names = [f"'{t}'" for t in active_triggers] + raise ValueError( + f"Multiple trigger inputs found ({', '.join(trigger_names)}). " + f"Only one trigger input can be provided for {self.__class__.__name__}." + ) + + # Get the block to run (use default if no triggers active) + block = self.trigger_to_block_map.get(active_triggers[0]) if active_triggers else default_block + if block is None: + logger.warning(f"No valid block found in {self.__class__.__name__}, skipping.") + return pipeline, state + + try: + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def __repr__(self): + class_name = self.__class__.__name__ + + # Components section + expected_components = set(getattr(self, "expected_components", [])) + loaded_components = set(self.components.keys()) + all_components = sorted(expected_components | loaded_components) + components_str = " Components:\n" + "\n".join( + f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + for k in all_components + ) + + # Auxiliaries section + expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) + loaded_auxiliaries = set(self.auxiliaries.keys()) + all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) + auxiliaries_str = " Auxiliaries:\n" + "\n".join( + f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" + for k in all_auxiliaries + ) + + # Configs section + expected_configs = set(getattr(self, "expected_configs", [])) + loaded_configs = set(self.configs.keys()) + all_configs = sorted(expected_configs | loaded_configs) + configs_str = " Configs:\n" + "\n".join( + f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs + ) + + # Blocks section with trigger information + blocks_str = " Blocks:\n" + for name, block in self.blocks.items(): + # Find trigger for this block + trigger = next((t for t, b in self.trigger_to_block_map.items() if b == block), None) + trigger_str = " (default)" if trigger is None else f" (triggered by: {trigger})" + + blocks_str += f" {name} ({block.__class__.__name__}){trigger_str}\n" + + # Add inputs information + if hasattr(block, "inputs"): + inputs_str = ", ".join(f"{name}={default}" for name, default in block.inputs) + if inputs_str: + blocks_str += f" inputs: {inputs_str}\n" + + # Add intermediates information + if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): + intermediates_str = "" + if hasattr(block, "intermediates_inputs"): + intermediates_str += f"{', '.join(block.intermediates_inputs)}" + + if hasattr(block, "intermediates_outputs"): + if intermediates_str: + intermediates_str += " -> " + intermediates_str += f"{', '.join(block.intermediates_outputs)}" + + if intermediates_str: + blocks_str += f" intermediates: {intermediates_str}\n" + blocks_str += "\n" + + # Pipeline interface information + inputs_str = " PipelineBlock Interface:\n" + inputs_str += " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) + + intermediates_str = ( + "\n Intermediates:\n" + f" - inputs: {', '.join(self.intermediates_inputs)}\n" + f" - outputs: {', '.join(self.intermediates_outputs)}" + ) + + return ( + f"{class_name}(\n" + f"{components_str}\n" + f"{auxiliaries_str}\n" + f"{configs_str}\n" + f"{blocks_str}\n" + f"{inputs_str}" + f"{intermediates_str}\n" + f")" + ) + + +class SequentialPipelineBlocks(MultiPipelineBlocks): + """ + A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + """ + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return combine_inputs(*(block.inputs for block in self.blocks.values())) + + @property + def intermediates_inputs(self) -> List[str]: + inputs = set() + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.update(input_name for input_name in block.intermediates_inputs if input_name not in outputs) + # Add this block's outputs + outputs.update(block.intermediates_outputs) + + return list(inputs) + + @property + def intermediates_outputs(self) -> List[str]: + return list(set().union(*(block.intermediates_outputs for block in self.blocks.values()))) + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + for block_name, block in self.blocks.items(): + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return pipeline, state + + def __repr__(self): + class_name = self.__class__.__name__ + + # Components section + expected_components = set(getattr(self, "expected_components", [])) + loaded_components = set(self.components.keys()) + all_components = sorted(expected_components | loaded_components) + components_str = " Components:\n" + "\n".join( + f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + for k in all_components + ) + + # Auxiliaries section + expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) + loaded_auxiliaries = set(self.auxiliaries.keys()) + all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) + auxiliaries_str = " Auxiliaries:\n" + "\n".join( + f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" + for k in all_auxiliaries + ) + + # Configs section + expected_configs = set(getattr(self, "expected_configs", [])) + loaded_configs = set(self.configs.keys()) + all_configs = sorted(expected_configs | loaded_configs) + configs_str = " Configs:\n" + "\n".join( + f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs + ) + + # Detailed blocks section with data flow + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" + + # Add inputs information + if hasattr(block, "inputs"): + inputs_str = ", ".join(f"{name}={default}" for name, default in block.inputs) + blocks_str += f" inputs: {inputs_str}\n" + + # Add intermediates information + if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): + intermediates_str = "" + if hasattr(block, "intermediates_inputs"): + intermediates_str += f"{', '.join(block.intermediates_inputs)}" + + if hasattr(block, "intermediates_outputs"): + if intermediates_str: + intermediates_str += " -> " + intermediates_str += f"{', '.join(block.intermediates_outputs)}" + + if intermediates_str: + blocks_str += f" intermediates: {intermediates_str}\n" + blocks_str += "\n" + + # Pipeline interface information + inputs_str = " PipelineBlock Interface:\n" + inputs_str += " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) + + intermediates_str = ( + "\n Intermediates:\n" + f" - inputs: {', '.join(self.intermediates_inputs)}\n" + f" - outputs: {', '.join(self.intermediates_outputs)}" + ) + + return ( + f"{class_name}(\n" + f"{components_str}\n" + f"{auxiliaries_str}\n" + f"{configs_str}\n" + f"{blocks_str}\n" + f"{inputs_str}" + f"{intermediates_str}\n" + f")" + ) class ModularPipelineBuilder(ConfigMixin): @@ -662,30 +1038,6 @@ def add_blocks(self, pipeline_blocks, at: int = -1): configs_to_add.update(block.configs) auxiliaries_to_add.update(block.auxiliaries) - # Validate all required components and auxiliaries after consolidation - for block in pipeline_blocks: - for required_component in block.required_components: - if ( - not hasattr(self, required_component) - and required_component not in components_to_add - or getattr(self, required_component, None) is None - and components_to_add.get(required_component) is None - ): - raise ValueError( - f"Cannot add block {block.__class__.__name__}: Required component {required_component} not found in pipeline" - ) - - for required_auxiliary in block.required_auxiliaries: - if ( - not hasattr(self, required_auxiliary) - and required_auxiliary not in auxiliaries_to_add - or getattr(self, required_auxiliary, None) is None - and auxiliaries_to_add.get(required_auxiliary) is None - ): - raise ValueError( - f"Cannot add block {block.__class__.__name__}: Required auxiliary {required_auxiliary} not found in pipeline" - ) - # Process all items in batches if components_to_add: self.register_modules(**components_to_add) @@ -720,95 +1072,6 @@ def replace_blocks(self, pipeline_blocks, at: int): # Remove the old blocks self.remove_blocks(indices_to_remove) - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_or_path, **kwargs): - # (1) create the base pipeline - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - - load_config_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "token": token, - "local_files_only": local_files_only, - "revision": revision, - } - - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - base_pipeline_class_name = config["_class_name"] - base_pipeline_class = _get_pipeline_class(cls, config) - - kwargs = {**load_config_kwargs, **kwargs} - base_pipeline = base_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) - - # (2) map the base pipeline to pipeline blocks - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] - modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) - - # (3) create the pipeline blocks - pipeline_blocks = [ - block_class.from_pipe(base_pipeline) for block_class in modular_pipeline_class.default_pipeline_blocks - ] - - # (4) create the builder - builder = modular_pipeline_class() - builder.add_blocks(pipeline_blocks) - - return builder - - @classmethod - def from_pipe(cls, pipeline, **kwargs): - base_pipeline_class_name = pipeline.__class__.__name__ - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[_get_model(base_pipeline_class_name)] - modular_pipeline_class = _get_pipeline_class(cls, config=None, class_name=modular_pipeline_class_name) - - pipeline_blocks = [] - # Create each block, passing only unused items that the block expects - for block_class in modular_pipeline_class.default_pipeline_blocks: - expected_components = set(block_class.required_components + block_class.optional_components) - expected_auxiliaries = set(block_class.required_auxiliaries + block_class.optional_auxiliaries) - - # Get init parameters to check for expected configs - init_params = inspect.signature(block_class.__init__).parameters - expected_configs = { - k for k in init_params if k not in expected_components and k not in expected_auxiliaries - } - - block_kwargs = {} - - for key, value in kwargs.items(): - if key in expected_components or key in expected_auxiliaries or key in expected_configs: - block_kwargs[key] = value - - # Create the block with filtered kwargs - block = block_class.from_pipe(pipeline, **block_kwargs) - pipeline_blocks.append(block) - - # Create and setup the builder - builder = modular_pipeline_class() - builder.add_blocks(pipeline_blocks) - - # Warn about unused kwargs - unused_kwargs = { - k: v - for k, v in kwargs.items() - if not any( - k in block.components or k in block.auxiliaries or k in block.configs for block in pipeline_blocks - ) - } - if unused_kwargs: - logger.warning( - f"The following items were passed but not used by any pipeline block: {list(unused_kwargs.keys())}" - ) - - return builder - def run_blocks(self, state: PipelineState = None, **kwargs): """ Run one or more blocks in sequence, optionally you can pass a previous pipeline state. @@ -821,19 +1084,22 @@ def run_blocks(self, state: PipelineState = None, **kwargs): default_params = self.default_call_parameters - # user can pass the intermediate of the first block - for name in self.pipeline_blocks[0].intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs + for name, default in default_params.items(): if name in input_params: - state.add_input(name, input_params.pop(name)) + if name not in self.pipeline_blocks[0].intermediates_inputs: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, input_params[name]) elif name not in state.inputs: state.add_input(name, default) + for name in self.pipeline_blocks[0].intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + # Warn about unexpected inputs if len(input_params) > 0: logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") @@ -873,8 +1139,12 @@ def run_pipeline(self, **kwargs): for block in self.pipeline_blocks: try: pipeline, state = block(self, state) - except Exception: - error_msg = f"Error in block: ({block.__class__.__name__}):\n" + except Exception as e: + error_msg = ( + f"\nError in block: ({block.__class__.__name__}):\n" + f"Error details: {str(e)}\n" + f"Stack trace:\n{traceback.format_exc()}" + ) logger.error(error_msg) raise diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 46d0491b11cc..948e2b11fd1d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -14,7 +14,6 @@ import inspect from typing import Any, List, Optional, Tuple, Union -from collections import OrderedDict import PIL import torch @@ -34,7 +33,13 @@ ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline_builder import ModularPipelineBuilder, PipelineBlock, PipelineState, make_auto_step +from ..modular_pipeline_builder import ( + AutoPipelineBlocks, + ModularPipelineBuilder, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import ( StableDiffusionXLPipelineOutput, @@ -130,6 +135,9 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_outputs(self) -> List[str]: return ["batch_size"] + def __init__(self): + super().__init__() + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: prompt = state.get_input("prompt") @@ -148,7 +156,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): - optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] + expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] + expected_configs = ["force_zeros_for_empty_prompt"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -315,8 +324,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - required_components = ["scheduler"] +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + expected_components = ["scheduler"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -325,7 +334,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("timesteps", None), ("sigmas", None), ("denoising_end", None), - ("image", None), ("strength", 0.3), ("denoising_start", None), ("num_images_per_prompt", 1), @@ -352,7 +360,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: device = state.get_input("device") # image to image only - image = state.get_input("image") # just to check if it is an image to image workflow strength = state.get_input("strength") denoising_start = state.get_input("denoising_start") num_images_per_prompt = state.get_input("num_images_per_prompt") @@ -367,21 +374,68 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: pipeline.scheduler, num_inference_steps, device, timesteps, sigmas ) - if image is not None: + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 + timesteps, num_inference_steps = pipeline.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - timesteps, num_inference_steps = pipeline.get_timesteps( - num_inference_steps, - strength, - device, - denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + pipeline.scheduler.config.num_train_timesteps + - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + ) ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + state.add_intermediate("timesteps", timesteps) + state.add_intermediate("num_inference_steps", num_inference_steps) + state.add_intermediate("latent_timestep", latent_timestep) + + return pipeline, state - else: - latent_timestep = None + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + expected_components = ["scheduler"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("num_inference_steps", 50), + ("timesteps", None), + ("sigmas", None), + ("denoising_end", None), + ("device", None), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return ["timesteps", "num_inference_steps"] + + def __init__(self, scheduler=None): + super().__init__(scheduler=scheduler) + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + num_inference_steps = state.get_input("num_inference_steps") + timesteps = state.get_input("timesteps") + sigmas = state.get_input("sigmas") + denoising_end = state.get_input("denoising_end") + device = state.get_input("device") + + if device is None: + device = pipeline._execution_device + + timesteps, num_inference_steps = retrieve_timesteps( + pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + ) if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( @@ -395,14 +449,13 @@ def denoising_value_valid(dnv): state.add_intermediate("timesteps", timesteps) state.add_intermediate("num_inference_steps", num_inference_steps) - state.add_intermediate("latent_timestep", latent_timestep) return pipeline, state -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - optional_components = ["vae", "scheduler"] - optional_auxiliaries = ["image_processor"] +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + expected_components = ["vae", "scheduler"] + expected_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -420,24 +473,89 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "latent_timestep"] + return ["batch_size", "latent_timestep", "prompt_embeds"] @property def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, vae=None, image_processor=None, vae_scale_factor=8, scheduler=None): - if image_processor is None: - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__( - vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor, scheduler=scheduler - ) + def __init__(self, vae=None, scheduler=None): + super().__init__(vae=vae, scheduler=scheduler) + self.image_processor = VaeImageProcessor() + self.auxiliaries["image_processor"] = self.image_processor - @staticmethod - def check_inputs(pipeline, height, width, image): - if image is not None and (height is not None or width is not None): - raise ValueError("Cannot specify both `image` and `height` or `width`") + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + latents = state.get_input("latents") + num_images_per_prompt = state.get_input("num_images_per_prompt") + generator = state.get_input("generator") + device = state.get_input("device") + dtype = state.get_input("dtype") + + # image to image only + image = state.get_input("image") + denoising_start = state.get_input("denoising_start") + + batch_size = state.get_intermediate("batch_size") + prompt_embeds = state.get_intermediate("prompt_embeds", None) + # image to image only + latent_timestep = state.get_intermediate("latent_timestep", None) + + if dtype is None and prompt_embeds is not None: + dtype = prompt_embeds.dtype + elif dtype is None: + dtype = pipeline.vae.dtype + + if device is None: + device = pipeline._execution_device + + image = pipeline.image_processor.preprocess(image) + add_noise = True if denoising_start is None else False + if latents is None: + latents = pipeline.prepare_latents_img2img( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + add_noise, + ) + + state.add_intermediate("latents", latents) + + return pipeline, state + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + expected_components = ["vae", "scheduler"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("height", None), + ("width", None), + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ("device", None), + ("dtype", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "prompt_embeds"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + + def __init__(self, vae=None, scheduler=None): + super().__init__(vae=vae, scheduler=scheduler) + + @staticmethod + def check_inputs(pipeline, height, width): if ( height is not None and height % pipeline.vae_scale_factor != 0 @@ -460,61 +578,39 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin height = state.get_input("height") width = state.get_input("width") - # image to image only - image = state.get_input("image") - denoising_start = state.get_input("denoising_start") - batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds", None) - # image to image only - latent_timestep = state.get_intermediate("latent_timestep", None) + prompt_embeds = state.get_intermediate("prompt_embeds") - if dtype is None and prompt_embeds is not None: + if dtype is None: dtype = prompt_embeds.dtype - elif dtype is None: - dtype = pipeline.vae.dtype if device is None: device = pipeline._execution_device - self.check_inputs(pipeline, height, width, image) - - if image is None: - height = height or pipeline.default_sample_size * pipeline.vae_scale_factor - width = width or pipeline.default_sample_size * pipeline.vae_scale_factor - num_channels_latents = pipeline.num_channels_latents - latents = pipeline.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents, - ) - else: - image = pipeline.image_processor.preprocess(image) - add_noise = True if denoising_start is None else False - if latents is None: - latents = pipeline.prepare_latents_img2img( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator, - add_noise, - ) + self.check_inputs(pipeline, height, width) + + height = height or pipeline.default_sample_size * pipeline.vae_scale_factor + width = width or pipeline.default_sample_size * pipeline.vae_scale_factor + num_channels_latents = pipeline.num_channels_latents + latents = pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents, + ) state.add_intermediate("latents", latents) return pipeline, state -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - required_components = ["unet"] +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + expected_components = ["unet"] + expected_configs = ["requires_aesthetics_score"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -530,7 +626,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("aesthetic_score", 6.0), ("negative_aesthetic_score", 2.0), ("device", None), - ("image", None), ] @property @@ -557,7 +652,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin device = state.get_input("device") # image to image only - image = state.get_input("image") aesthetic_score = state.get_input("aesthetic_score") negative_aesthetic_score = state.get_input("negative_aesthetic_score") @@ -580,51 +674,123 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin else: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - if image is None: - add_time_ids = pipeline._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = pipeline._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( - device=device - ) + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + # Optionally get Guidance Scale Embedding for LCM + timestep_cond = None + if ( + hasattr(pipeline, "unet") + and pipeline.unet is not None + and pipeline.unet.config.time_cond_proj_dim is not None + ): + guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = pipeline.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + state.add_intermediate("add_time_ids", add_time_ids) + state.add_intermediate("negative_add_time_ids", negative_add_time_ids) + state.add_intermediate("timestep_cond", timestep_cond) + return pipeline, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + expected_components = ["unet"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("original_size", None), + ("target_size", None), + ("negative_original_size", None), + ("negative_target_size", None), + ("crops_coords_top_left", (0, 0)), + ("negative_crops_coords_top_left", (0, 0)), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("device", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["latents", "batch_size", "pooled_prompt_embeds"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + + def __init__(self, unet=None): + super().__init__(unet=unet) + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + original_size = state.get_input("original_size") + target_size = state.get_input("target_size") + negative_original_size = state.get_input("negative_original_size") + negative_target_size = state.get_input("negative_target_size") + crops_coords_top_left = state.get_input("crops_coords_top_left") + negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") + num_images_per_prompt = state.get_input("num_images_per_prompt") + guidance_scale = state.get_input("guidance_scale") + device = state.get_input("device") + + latents = state.get_intermediate("latents") + batch_size = state.get_intermediate("batch_size") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + + if device is None: + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: + text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim else: - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size - - add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = pipeline._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = pipeline._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, - dtype=pooled_prompt_embeds.dtype, + pooled_prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to( - device=device - ) + else: + negative_add_time_ids = add_time_ids + negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) # Optionally get Guidance Scale Embedding for LCM timestep_cond = None @@ -645,8 +811,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLDenoiseStep(PipelineBlock): - required_components = ["unet", "scheduler"] - optional_auxiliaries = ["guider"] + expected_components = ["unet", "scheduler", "guider"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -780,8 +945,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - required_components = ["unet", "controlnet", "scheduler"] - optional_auxiliaries = ["guider", "controlnet_guider", "control_image_processor"] + expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] + expected_auxiliaries = ["control_image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -827,26 +992,24 @@ def __init__( scheduler=None, guider=None, controlnet_guider=None, - control_image_processor=None, vae_scale_factor=8.0, ): if guider is None: guider = CFGGuider() if controlnet_guider is None: controlnet_guider = CFGGuider() - if control_image_processor is None: - control_image_processor = VaeImageProcessor( - vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) super().__init__( unet=unet, controlnet=controlnet, scheduler=scheduler, guider=guider, controlnet_guider=controlnet_guider, - control_image_processor=control_image_processor, vae_scale_factor=vae_scale_factor, ) + control_image_processor = VaeImageProcessor( + vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.auxiliaries["control_image_processor"] = control_image_processor @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1017,6 +1180,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = pipeline.controlnet( pipeline.scheduler.scale_model_input(control_model_input, t), t, @@ -1069,8 +1233,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - optional_components = ["vae"] - optional_auxiliaries = ["image_processor"] + expected_components = ["vae"] + expected_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1087,10 +1251,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["images"] - def __init__(self, vae=None, image_processor=None, vae_scale_factor=8): - if image_processor is None: - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - super().__init__(vae=vae, image_processor=image_processor, vae_scale_factor=vae_scale_factor) + def __init__(self, vae=None, vae_scale_factor=8): + super().__init__(vae=vae, vae_scale_factor=vae_scale_factor) + self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=vae_scale_factor) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1155,31 +1318,60 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -AUTO_DENOISE_BLOCK_MAP = OrderedDict([ - # Higher priority blocks first - ("control_image", StableDiffusionXLControlNetDenoiseStep), - # Default block - (None, StableDiffusionXLDenoiseStep), -]) +class StableDiffusionXLAutoSetTimestepsStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLSetTimestepsStep, StableDiffusionXLImg2ImgSetTimestepsStep] + block_prefixes = ["", "img2img"] + block_trigger_inputs = [None, "image"] -StableDiffusionXLAutoDenoiseStep = make_auto_step(AUTO_DENOISE_BLOCK_MAP) -class StableDiffusionXLModularPipeline( - ModularPipelineBuilder, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, -): - default_pipeline_blocks = [ +class StableDiffusionXLAutoPrepareLatentsStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareLatentsStep] + block_prefixes = ["", "img2img"] + block_trigger_inputs = [None, "image"] + + +class StableDiffusionXLAutoPrepareAdditionalConditioningStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_prefixes = ["", "img2img"] + block_trigger_inputs = [None, "image"] + + +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep] + block_prefixes = ["", "controlnet"] + block_trigger_inputs = [None, "control_image"] + + +class StableDiffusionXLAllSteps(SequentialPipelineBlocks): + block_classes = [ StableDiffusionXLInputStep, StableDiffusionXLTextEncoderStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLAutoSetTimestepsStep, + StableDiffusionXLAutoPrepareLatentsStep, + StableDiffusionXLAutoPrepareAdditionalConditioningStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLDecodeLatentsStep, ] + block_prefixes = [ + "input", + "text_encoder", + "set_timesteps", + "prepare_latents", + "prepare_add_cond", + "denoise", + "decode_latents", + ] + +class StableDiffusionXLModularPipeline( + ModularPipelineBuilder, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, +): def __init__(self): super().__init__() diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index eaab67c93b18..d796a4968ea2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -960,6 +960,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ModularPipelineBuilder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d8109eee6d35..02462b12d091 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1892,6 +1892,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From a8df0f1ffb46282f2e56161a79d0d709e8a6e758 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 10 Dec 2024 18:22:42 +0000 Subject: [PATCH 020/170] Modular APG (#10173) --- src/diffusers/guider.py | 230 ++++++++++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 3 +- 2 files changed, 232 insertions(+), 1 deletion(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index 96dced267baa..e58afe61574d 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -188,6 +188,7 @@ def apply_guidance( self, model_output: torch.Tensor, timestep: int = None, + latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if not self.do_classifier_free_guidance: return model_output @@ -476,6 +477,7 @@ def apply_guidance( self, model_output: torch.Tensor, timestep: int, + latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if not self.do_perturbed_attention_guidance: return model_output @@ -501,3 +503,231 @@ def apply_guidance( noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) return noise_pred + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +class APGGuider: + """ + This class is used to guide the pipeline with APG (Adaptive Projected Guidance). + """ + + def normalized_guidance( + self, + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: MomentumBuffer = None, + norm_threshold: float = 0.0, + eta: float = 1.0, + ): + """ + Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales + in Diffusion Models](https://arxiv.org/pdf/2410.02416) + """ + diff = pred_cond - pred_uncond + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) + normalized_update = diff_orthogonal + eta * diff_parallel + pred_guided = pred_cond + (guidance_scale - 1) * normalized_update + return pred_guided + + @property + def adaptive_projected_guidance_momentum(self): + return self._adaptive_projected_guidance_momentum + + @property + def adaptive_projected_guidance_rescale_factor(self): + return self._adaptive_projected_guidance_rescale_factor + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 and not self._disable_guidance + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def batch_size(self): + return self._batch_size + + def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): + disable_guidance = guider_kwargs.get("disable_guidance", False) + guidance_scale = guider_kwargs.get("guidance_scale", None) + if guidance_scale is None: + raise ValueError("guidance_scale is not provided in guider_kwargs") + adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) + adaptive_projected_guidance_rescale_factor = guider_kwargs.get( + "adaptive_projected_guidance_rescale_factor", 15.0 + ) + guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) + batch_size = guider_kwargs.get("batch_size", None) + if batch_size is None: + raise ValueError("batch_size is not provided in guider_kwargs") + self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._batch_size = batch_size + self._disable_guidance = disable_guidance + if adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) + else: + self.momentum_buffer = None + self.scheduler = pipeline.scheduler + + def reset_guider(self, pipeline): + pass + + def maybe_update_guider(self, pipeline, timestep): + pass + + def maybe_update_input(self, pipeline, cond_input): + pass + + def _maybe_split_prepared_input(self, cond): + """ + Process and potentially split the conditional input for Classifier-Free Guidance (CFG). + + This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). + It determines whether to split the input based on its batch size relative to the expected batch size. + + Args: + cond (torch.Tensor): The conditional input tensor to process. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The negative conditional input (uncond_input) + - The positive conditional input (cond_input) + """ + if cond.shape[0] == self.batch_size * 2: + neg_cond = cond[0 : self.batch_size] + cond = cond[self.batch_size :] + return neg_cond, cond + elif cond.shape[0] == self.batch_size: + return cond, cond + else: + raise ValueError(f"Unsupported input shape: {cond.shape}") + + def _is_prepared_input(self, cond): + """ + Check if the input is already prepared for Classifier-Free Guidance (CFG). + + Args: + cond (torch.Tensor): The conditional input tensor to check. + + Returns: + bool: True if the input is already prepared, False otherwise. + """ + cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond + + return cond_tensor.shape[0] == self.batch_size * 2 + + def prepare_input( + self, + cond_input: Union[torch.Tensor, List[torch.Tensor]], + negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Prepare the input for CFG. + + Args: + cond_input (Union[torch.Tensor, List[torch.Tensor]]): + The conditional input. It can be a single tensor or a + list of tensors. It must have the same length as `negative_cond_input`. + negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a + single tensor or a list of tensors. It must have the same length as `cond_input`. + + Returns: + Union[torch.Tensor, List[torch.Tensor]]: The prepared input. + """ + + # we check if cond_input already has CFG applied, and split if it is the case. + if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: + return cond_input + + if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: + if isinstance(cond_input, list): + negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) + else: + negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) + + if not self._is_prepared_input(cond_input) and negative_cond_input is None: + raise ValueError( + "`negative_cond_input` is required when cond_input does not already contains negative conditional input" + ) + + if isinstance(cond_input, (list, tuple)): + if not self.do_classifier_free_guidance: + return cond_input + + if len(negative_cond_input) != len(cond_input): + raise ValueError("The length of negative_cond_input and cond_input must be the same.") + prepared_input = [] + for neg_cond, cond in zip(negative_cond_input, cond_input): + if neg_cond.shape[0] != cond.shape[0]: + raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") + prepared_input.append(torch.cat([neg_cond, cond], dim=0)) + return prepared_input + + elif isinstance(cond_input, torch.Tensor): + if not self.do_classifier_free_guidance: + return cond_input + else: + return torch.cat([negative_cond_input, cond_input], dim=0) + + else: + raise ValueError(f"Unsupported input type: {type(cond_input)}") + + def apply_guidance( + self, + model_output: torch.Tensor, + timestep: int = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if not self.do_classifier_free_guidance: + return model_output + + if latents is None: + raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") + + sigma = self.scheduler.sigmas[self.scheduler.step_index] + noise_pred = latents - sigma * model_output + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = self.normalized_guidance( + noise_pred_text, + noise_pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.adaptive_projected_guidance_rescale_factor, + ) + noise_pred = (latents - noise_pred) / sigma + + if self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + return noise_pred diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 948e2b11fd1d..971f7336cf76 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -926,6 +926,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: noise_pred = pipeline.guider.apply_guidance( noise_pred, timestep=t, + latents=latents, ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -1213,7 +1214,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t) + noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From e50d614636aad7276219fd22a845fcc4e7fc5765 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 11 Dec 2024 03:39:39 +0100 Subject: [PATCH 021/170] only add model as expected_component when the model need to run for the block, currently it's added even when only config is needed --- src/diffusers/guider.py | 4 ++-- .../pipeline_stable_diffusion_xl_modular.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index e58afe61574d..e00e8878d021 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -530,8 +530,8 @@ def normalized_guidance( eta: float = 1.0, ): """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales - in Diffusion Models](https://arxiv.org/pdf/2410.02416) + Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion + Models](https://arxiv.org/pdf/2410.02416) """ diff = pred_cond - pred_uncond if momentum_buffer is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 971f7336cf76..c6435daf0ef9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -529,7 +529,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - expected_components = ["vae", "scheduler"] + expected_components = ["scheduler"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -551,8 +551,8 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, vae=None, scheduler=None): - super().__init__(vae=vae, scheduler=scheduler) + def __init__(self, scheduler=None): + super().__init__(scheduler=scheduler) @staticmethod def check_inputs(pipeline, height, width): @@ -609,7 +609,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - expected_components = ["unet"] expected_configs = ["requires_aesthetics_score"] @property @@ -636,8 +635,8 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - def __init__(self, unet=None, requires_aesthetics_score=False): - super().__init__(unet=unet, requires_aesthetics_score=requires_aesthetics_score) + def __init__(self, requires_aesthetics_score=False): + super().__init__(requires_aesthetics_score=requires_aesthetics_score) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -713,8 +712,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - expected_components = ["unet"] - @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -737,8 +734,8 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - def __init__(self, unet=None): - super().__init__(unet=unet) + def __init__(self): + super().__init__() @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: From bc3d1c9ee6339080ac44b97bcc62f6e3c9805a69 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 14 Dec 2024 00:24:15 +0100 Subject: [PATCH 022/170] add model_cpu_offload_seq + _exlude_from_cpu_offload --- .../pipelines/modular_pipeline_builder.py | 105 +++++++++++++++++- .../pipeline_stable_diffusion_xl_modular.py | 2 + 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index 590165e170c5..ecc75cba5eb6 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -100,6 +100,9 @@ class PipelineBlock: expected_components = [] expected_auxiliaries = [] expected_configs = [] + model_cpu_offload_seq = None + _exclude_from_cpu_offload=[] + @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -114,6 +117,24 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [] + @property + def model_cpu_offload_seq(self): + + model_cpu_offload_seq = [] + block_component_names = set([k for k, v in self.components.items() if isinstance(v, torch.nn.Module)]) + if len(block_component_names) <= 1: + return None + else: + if self.model_cpu_offload_seq is None: + raise ValueError(f"Block {self.__class__.__name__} has multiple components but no model_cpu_offload_seq specified") + for model_str in self.model_cpu_offload_seq.split("->"): + if model_str in block_component_names: + model_cpu_offload_seq.append(block_component_names.pop(model_str)) + if len(block_component_names) > 0: + raise ValueError(f"Block {self.__class__.__name__} has components {block_component_names} that are not in model_cpu_offload_seq {self.model_cpu_offload_seq}") + return "->".join(model_cpu_offload_seq) + + def update_states(self, **kwargs): """ Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for @@ -271,7 +292,8 @@ class MultiPipelineBlocks: """ block_classes = [] - block_names = [] + block_prefixes = [] + model_cpu_offload_seq = None @property def expected_components(self): @@ -344,6 +366,14 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: raise NotImplementedError("intermediates_outputs property must be implemented in subclasses") + @property + def model_cpu_offload_seq(self): + raise NotImplementedError("model_cpu_offload_seq property must be implemented in subclasses") + + @property + def _exclude_from_cpu_offload(self): + raise NotImplementedError("_exclude_from_cpu_offload property must be implemented in subclasses") + def __call__(self, pipeline, state): raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -524,7 +554,7 @@ def __post_init__(self): # Map trigger inputs to block objects self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - + @property def inputs(self) -> List[Tuple[str, Any]]: return combine_inputs(*(block.inputs for block in self.blocks.values())) @@ -573,6 +603,36 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) logger.error(error_msg) raise + + @property + def model_cpu_offload_seq(self): + """ + This is a sequence of model names that are expected to be offloaded to CPU + """ + model_cpu_offload_seq = None + for block_name, block in self.blocks.items(): + if block.model_cpu_offload_seq is not None: + if model_cpu_offload_seq is None: + model_cpu_offload_seq = block.model_cpu_offload_seq + else: + if len(block.model_cpu_offload_seq.split("->")) > len(model_cpu_offload_seq.split("->")): + model_cpu_offload_seq = block.model_cpu_offload_seq + return model_cpu_offload_seq + + + @property + def _exclude_from_cpu_offload(self): + model_cpu_offload_seq = None + for block_name, block in self.blocks.items(): + if block.model_cpu_offload_seq is not None: + if model_cpu_offload_seq is None: + model_cpu_offload_seq = block.model_cpu_offload_seq + else: + if len(block.model_cpu_offload_seq.split("->")) > len(model_cpu_offload_seq.split("->")): + model_cpu_offload_seq = block.model_cpu_offload_seq + return model_cpu_offload_seq + + def __repr__(self): class_name = self.__class__.__name__ @@ -697,6 +757,34 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise return pipeline, state + @property + def model_cpu_offload_seq(self): + model_cpu_offload_seq = [] + + for block_name, block in self.blocks.items(): + block_components_names = set([k for k, v in block.components.items() if isinstance(v, torch.nn.Module)]) + block_components_names = [b for b in block_components_names if b not in self.block._exclude_from_cpu_offload] + if len(block_components_names) == 0: + continue + if len(block_components_names) == 1: + model_cpu_offload_seq.append(block_components_names.pop()) + else: + if block.model_cpu_offload_seq is None: + raise ValueError(f"Block {block_name}:{block.__class__.__name__} has multiple components {block_components_names} but no model_cpu_offload_seq specified") + for model_str in block.model_cpu_offload_seq.split("->"): + if model_str in block_components_names: + model_cpu_offload_seq.append(block_components_names.pop(model_str)) + if len(block_components_names) > 0: + raise ValueError(f"Block {block_name}:{block.__class__.__name__} has components {block_components_names} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}") + return "->".join(model_cpu_offload_seq) + + @property + def _exclude_from_cpu_offload(self): + exclude_from_cpu_offload = set() + for block in self.blocks.values(): + exclude_from_cpu_offload.update(block._exclude_from_cpu_offload) + return list(exclude_from_cpu_offload) + def __repr__(self): class_name = self.__class__.__name__ @@ -1372,3 +1460,16 @@ def module_is_offloaded(module): " `torch_dtype=torch.float16` argument, or use another device for inference." ) return self + + + def remove_all_hooks(self): + for _, model in self.components.items(): + if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): + accelerate.hooks.remove_hook_from_module(model, recurse=True) + self._all_hooks = [] + + def find_model_sequence(self): + pass + + + # def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda", model_cpu_offload_seq: Optional[List[str]] = None): \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c6435daf0ef9..c8126532f9b1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -158,6 +158,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] expected_configs = ["force_zeros_for_empty_prompt"] + model_cpu_offload_seq = "text_encoder->text_encoder_2" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -945,6 +946,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] expected_auxiliaries = ["control_image_processor"] + _exclude_from_cpu_offload = ["controlnet"] @property def inputs(self) -> List[Tuple[str, Any]]: From 2b3cd2d39caa65093600939f9abeb300001ffe9e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 14 Dec 2024 03:02:31 +0100 Subject: [PATCH 023/170] update --- .../pipelines/modular_pipeline_builder.py | 69 ++++++------------- .../pipeline_stable_diffusion_xl_modular.py | 4 +- 2 files changed, 22 insertions(+), 51 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index ecc75cba5eb6..1f6967ffc80b 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -100,8 +100,7 @@ class PipelineBlock: expected_components = [] expected_auxiliaries = [] expected_configs = [] - model_cpu_offload_seq = None - _exclude_from_cpu_offload=[] + _model_cpu_offload_seq = None @property @@ -119,19 +118,23 @@ def intermediates_outputs(self) -> List[str]: @property def model_cpu_offload_seq(self): + """ + adjust the model_cpu_offload_seq to reflect actual components loaded in the block + """ model_cpu_offload_seq = [] - block_component_names = set([k for k, v in self.components.items() if isinstance(v, torch.nn.Module)]) - if len(block_component_names) <= 1: + block_component_names = [k for k, v in self.components.items() if isinstance(v, torch.nn.Module)] + if len(block_component_names) ==0: return None + if len(block_component_names) == 1: + return block_component_names[0] else: - if self.model_cpu_offload_seq is None: + if self._model_cpu_offload_seq is None: raise ValueError(f"Block {self.__class__.__name__} has multiple components but no model_cpu_offload_seq specified") - for model_str in self.model_cpu_offload_seq.split("->"): - if model_str in block_component_names: - model_cpu_offload_seq.append(block_component_names.pop(model_str)) - if len(block_component_names) > 0: - raise ValueError(f"Block {self.__class__.__name__} has components {block_component_names} that are not in model_cpu_offload_seq {self.model_cpu_offload_seq}") + model_cpu_offload_seq = [m for m in self._model_cpu_offload_seq.split("->") if m in block_component_names] + remaining = [m for m in block_component_names if m not in model_cpu_offload_seq] + if remaining: + logger.warning(f"Block {self.__class__.__name__} has components {remaining} that are not in model_cpu_offload_seq {self._model_cpu_offload_seq}") return "->".join(model_cpu_offload_seq) @@ -293,7 +296,7 @@ class MultiPipelineBlocks: block_classes = [] block_prefixes = [] - model_cpu_offload_seq = None + _model_cpu_offload_seq = None @property def expected_components(self): @@ -369,10 +372,6 @@ def intermediates_outputs(self) -> List[str]: @property def model_cpu_offload_seq(self): raise NotImplementedError("model_cpu_offload_seq property must be implemented in subclasses") - - @property - def _exclude_from_cpu_offload(self): - raise NotImplementedError("_exclude_from_cpu_offload property must be implemented in subclasses") def __call__(self, pipeline, state): raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -606,32 +605,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: @property def model_cpu_offload_seq(self): - """ - This is a sequence of model names that are expected to be offloaded to CPU - """ - model_cpu_offload_seq = None - for block_name, block in self.blocks.items(): - if block.model_cpu_offload_seq is not None: - if model_cpu_offload_seq is None: - model_cpu_offload_seq = block.model_cpu_offload_seq - else: - if len(block.model_cpu_offload_seq.split("->")) > len(model_cpu_offload_seq.split("->")): - model_cpu_offload_seq = block.model_cpu_offload_seq - return model_cpu_offload_seq - - - @property - def _exclude_from_cpu_offload(self): - model_cpu_offload_seq = None - for block_name, block in self.blocks.items(): - if block.model_cpu_offload_seq is not None: - if model_cpu_offload_seq is None: - model_cpu_offload_seq = block.model_cpu_offload_seq - else: - if len(block.model_cpu_offload_seq.split("->")) > len(model_cpu_offload_seq.split("->")): - model_cpu_offload_seq = block.model_cpu_offload_seq - return model_cpu_offload_seq + default_block = self.trigger_to_block_map.get(None) + + return default_block.model_cpu_offload_seq def __repr__(self): @@ -763,7 +740,6 @@ def model_cpu_offload_seq(self): for block_name, block in self.blocks.items(): block_components_names = set([k for k, v in block.components.items() if isinstance(v, torch.nn.Module)]) - block_components_names = [b for b in block_components_names if b not in self.block._exclude_from_cpu_offload] if len(block_components_names) == 0: continue if len(block_components_names) == 1: @@ -773,17 +749,12 @@ def model_cpu_offload_seq(self): raise ValueError(f"Block {block_name}:{block.__class__.__name__} has multiple components {block_components_names} but no model_cpu_offload_seq specified") for model_str in block.model_cpu_offload_seq.split("->"): if model_str in block_components_names: - model_cpu_offload_seq.append(block_components_names.pop(model_str)) + model_cpu_offload_seq.append(model_str) + block_components_names.remove(model_str) if len(block_components_names) > 0: - raise ValueError(f"Block {block_name}:{block.__class__.__name__} has components {block_components_names} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}") + logger.warning(f"Block {block_name}:{block.__class__.__name__} has components {block_components_names} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}") return "->".join(model_cpu_offload_seq) - @property - def _exclude_from_cpu_offload(self): - exclude_from_cpu_offload = set() - for block in self.blocks.values(): - exclude_from_cpu_offload.update(block._exclude_from_cpu_offload) - return list(exclude_from_cpu_offload) def __repr__(self): class_name = self.__class__.__name__ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c8126532f9b1..56df056c78e3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -158,7 +158,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] expected_configs = ["force_zeros_for_empty_prompt"] - model_cpu_offload_seq = "text_encoder->text_encoder_2" + _model_cpu_offload_seq = "text_encoder->text_encoder_2" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -946,7 +946,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] expected_auxiliaries = ["control_image_processor"] - _exclude_from_cpu_offload = ["controlnet"] + _model_cpu_offload_seq = "unet" @property def inputs(self) -> List[Tuple[str, Any]]: From b305c779b2d197c0f7f2f48f5a9e761ff77e6eaa Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 14 Dec 2024 21:37:21 +0100 Subject: [PATCH 024/170] add offload support! --- .../pipelines/modular_pipeline_builder.py | 190 ++++++++++++++---- 1 file changed, 152 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index 1f6967ffc80b..faedada96073 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -16,7 +16,7 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from tqdm.auto import tqdm @@ -102,7 +102,6 @@ class PipelineBlock: expected_configs = [] _model_cpu_offload_seq = None - @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: # (input_name, default_value) @@ -124,20 +123,23 @@ def model_cpu_offload_seq(self): model_cpu_offload_seq = [] block_component_names = [k for k, v in self.components.items() if isinstance(v, torch.nn.Module)] - if len(block_component_names) ==0: + if len(block_component_names) == 0: return None if len(block_component_names) == 1: return block_component_names[0] else: if self._model_cpu_offload_seq is None: - raise ValueError(f"Block {self.__class__.__name__} has multiple components but no model_cpu_offload_seq specified") + raise ValueError( + f"Block {self.__class__.__name__} has multiple components but no model_cpu_offload_seq specified" + ) model_cpu_offload_seq = [m for m in self._model_cpu_offload_seq.split("->") if m in block_component_names] remaining = [m for m in block_component_names if m not in model_cpu_offload_seq] if remaining: - logger.warning(f"Block {self.__class__.__name__} has components {remaining} that are not in model_cpu_offload_seq {self._model_cpu_offload_seq}") + logger.warning( + f"Block {self.__class__.__name__} has components {remaining} that are not in model_cpu_offload_seq {self._model_cpu_offload_seq}" + ) return "->".join(model_cpu_offload_seq) - - + def update_states(self, **kwargs): """ Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for @@ -553,7 +555,7 @@ def __post_init__(self): # Map trigger inputs to block objects self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - + @property def inputs(self) -> List[Tuple[str, Any]]: return combine_inputs(*(block.inputs for block in self.blocks.values())) @@ -602,15 +604,13 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) logger.error(error_msg) raise - + @property def model_cpu_offload_seq(self): - default_block = self.trigger_to_block_map.get(None) return default_block.model_cpu_offload_seq - def __repr__(self): class_name = self.__class__.__name__ @@ -669,11 +669,6 @@ def __repr__(self): if intermediates_str: blocks_str += f" intermediates: {intermediates_str}\n" blocks_str += "\n" - - # Pipeline interface information - inputs_str = " PipelineBlock Interface:\n" - inputs_str += " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) - intermediates_str = ( "\n Intermediates:\n" f" - inputs: {', '.join(self.intermediates_inputs)}\n" @@ -686,7 +681,6 @@ def __repr__(self): f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" - f"{inputs_str}" f"{intermediates_str}\n" f")" ) @@ -737,24 +731,33 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: @property def model_cpu_offload_seq(self): model_cpu_offload_seq = [] - + for block_name, block in self.blocks.items(): - block_components_names = set([k for k, v in block.components.items() if isinstance(v, torch.nn.Module)]) - if len(block_components_names) == 0: + block_components = [k for k, v in block.components.items() if isinstance(v, torch.nn.Module)] + if len(block_components) == 0: continue - if len(block_components_names) == 1: - model_cpu_offload_seq.append(block_components_names.pop()) + if len(block_components) == 1: + if block_components[0] in model_cpu_offload_seq: + model_cpu_offload_seq.remove(block_components[0]) + model_cpu_offload_seq.append(block_components[0]) else: if block.model_cpu_offload_seq is None: - raise ValueError(f"Block {block_name}:{block.__class__.__name__} has multiple components {block_components_names} but no model_cpu_offload_seq specified") + raise ValueError( + f"Block {block_name}:{block.__class__.__name__} has multiple components {block_components} but no model_cpu_offload_seq specified" + ) for model_str in block.model_cpu_offload_seq.split("->"): - if model_str in block_components_names: + if model_str in block_components: + # if it is already in the list,remove previous occurence and add to the end + if model_str in model_cpu_offload_seq: + model_cpu_offload_seq.remove(model_str) model_cpu_offload_seq.append(model_str) - block_components_names.remove(model_str) - if len(block_components_names) > 0: - logger.warning(f"Block {block_name}:{block.__class__.__name__} has components {block_components_names} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}") + block_components.remove(model_str) + if len(block_components) > 0: + logger.warning( + f"Block {block_name}:{block.__class__.__name__} has components {block_components} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}" + ) + return "->".join(model_cpu_offload_seq) - def __repr__(self): class_name = self.__class__.__name__ @@ -810,10 +813,6 @@ def __repr__(self): blocks_str += f" intermediates: {intermediates_str}\n" blocks_str += "\n" - # Pipeline interface information - inputs_str = " PipelineBlock Interface:\n" - inputs_str += " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) - intermediates_str = ( "\n Intermediates:\n" f" - inputs: {', '.join(self.intermediates_inputs)}\n" @@ -826,7 +825,6 @@ def __repr__(self): f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" - f"{inputs_str}" f"{intermediates_str}\n" f")" ) @@ -1004,6 +1002,7 @@ def set_progress_bar_config(self, **kwargs): def __call__(self, *args, **kwargs): raise NotImplementedError("__call__ is not implemented for ModularPipelineBuilder") + # YiYi Notes: do we need to support multiple blocks? def remove_blocks(self, indices: Union[int, List[int]]): """ Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and @@ -1060,6 +1059,8 @@ def remove_blocks(self, indices: Union[int, List[int]]): if config_name in self.config: del self.config[config_name] + # YiYi Notes: I left all the functionalities to support adding multiple blocks + # but I wonder if it is still needed now we have `SequentialBlocks` and user can always combine them into one before adding to the builder def add_blocks(self, pipeline_blocks, at: int = -1): """Add blocks to the pipeline. @@ -1171,6 +1172,7 @@ def run_blocks(self, state: PipelineState = None, **kwargs): error_msg = f"Error in block: ({block.__class__.__name__}):\n" logger.error(error_msg) raise + self.maybe_free_model_hooks() return state @@ -1206,6 +1208,7 @@ def run_pipeline(self, **kwargs): ) logger.error(error_msg) raise + self.maybe_free_model_hooks() return state.get_output("images") @@ -1226,7 +1229,14 @@ def __repr__(self): output += "Pipeline Blocks:\n" output += "----------------\n" for i, block in enumerate(self.pipeline_blocks): - output += f"{i}. {block.__class__.__name__}\n" + if isinstance(block, MultiPipelineBlocks): + output += f"{i}. {block.__class__.__name__} - (CPU offload seq: {block.model_cpu_offload_seq})\n" + # Add sub-blocks information + for sub_block_name, sub_block in block.blocks.items(): + output += f" • {sub_block_name} ({sub_block.__class__.__name__}) \n" + else: + output += f"{i}. {block.__class__.__name__} - (CPU offload seq: {block.model_cpu_offload_seq})\n" + output += "\n" intermediates_str = "" if hasattr(block, "intermediates_inputs"): @@ -1432,15 +1442,119 @@ def module_is_offloaded(module): ) return self - def remove_all_hooks(self): for _, model in self.components.items(): if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): accelerate.hooks.remove_hook_from_module(model, recurse=True) self._all_hooks = [] - + def find_model_sequence(self): pass - - # def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda", model_cpu_offload_seq: Optional[List[str]] = None): \ No newline at end of file + # YiYi notes: assume there is only one pipeline block now (still debating if we want to support multiple pipeline blocks) + @property + def model_cpu_offload_seq(self): + return self.pipeline_blocks[0].model_cpu_offload_seq + + def enable_model_cpu_offload( + self, + gpu_id: Optional[int] = None, + device: Union[torch.device, str] = "cuda", + model_cpu_offload_seq: Optional[str] = None, + ): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + + Arguments: + gpu_id (`int`, *optional*): + The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. + device (`torch.Device` or `str`, *optional*, defaults to "cuda"): + The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will + default to "cuda". + """ + _exclude_from_cpu_offload = [] # YiYi Notes: this is not used (keep the variable for now) + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." + ) + + model_cpu_offload_seq = model_cpu_offload_seq or self.model_cpu_offload_seq + self._model_cpu_offload_seq_used = model_cpu_offload_seq + if model_cpu_offload_seq is None: + raise ValueError( + "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set or passed." + ) + + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + self.remove_all_hooks() + + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 + self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) + + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") + self._offload_device = device + + self.to("cpu", silence_dtype_warnings=True) + device_mod = getattr(torch, device.type, None) + if hasattr(device_mod, "empty_cache") and device_mod.is_available(): + device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + + self._all_hooks = [] + hook = None + for model_str in model_cpu_offload_seq.split("->"): + model = all_model_components.pop(model_str, None) + if not isinstance(model, torch.nn.Module): + continue + + _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) + self._all_hooks.append(hook) + + # CPU offload models that are not in the seq chain unless they are explicitly excluded + # these models will stay on CPU until maybe_free_model_hooks is called + # some models cannot be in the seq chain because they are iteratively called, such as controlnet + for name, model in all_model_components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if name in _exclude_from_cpu_offload: + model.to(device) + else: + _, hook = cpu_offload_with_hook(model, device) + self._all_hooks.append(hook) + + def maybe_free_model_hooks(self): + r""" + Function that offloads all components, removes all model hooks that were added when using + `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function + is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it + functions correctly when applying enable_model_cpu_offload. + """ + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: + # `enable_model_cpu_offload` has not be called, so silently do nothing + return + + # make sure the model is in the same state as before calling it + self.enable_model_cpu_offload( + device=getattr(self, "_offload_device", "cuda"), + model_cpu_offload_seq=getattr(self, "_model_cpu_offload_seq_used", None), + ) From 0b90051db8da106c6c5a47fddac90b03165dc970 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 19 Dec 2024 17:57:12 +0100 Subject: [PATCH 025/170] add vae encoder node --- .../pipeline_controlnet_sd_xl_img2img.py | 11 +- .../kolors/pipeline_kolors_img2img.py | 11 +- .../pipeline_pag_controlnet_sd_xl_img2img.py | 11 +- .../pag/pipeline_pag_sd_xl_img2img.py | 11 +- .../pipeline_stable_diffusion_xl_img2img.py | 11 +- .../pipeline_stable_diffusion_xl_modular.py | 111 ++++++++++++++++-- 6 files changed, 128 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 21cd87f7570e..c8b3bc209c0f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -899,12 +899,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -918,6 +912,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 4985a80f88df..37123d854da9 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -607,12 +607,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -626,6 +620,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index 66398483e046..06f02d7b0065 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -905,12 +905,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -924,6 +918,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 4c2c4e5aa3fa..62de8ba41927 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -691,12 +691,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -710,6 +704,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 29b5e11875fc..050c1009b136 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -682,12 +682,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -701,6 +695,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 56df056c78e3..47229594363a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -325,6 +325,102 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class StableDiffusionXLVAEEncoderStep(PipelineBlock): + expected_components = ["vae"] + expected_auxiliaries = ["image_processor"] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("image", None), + ("generator", None), + ("height", None), + ("width", None), + ("device", None), + ("dtype", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["image_latents"] + + def __init__(self, vae=None): + super().__init__(vae=vae) + self.image_processor = VaeImageProcessor() + self.auxiliaries["image_processor"] = self.image_processor + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + image = state.get_input("image") + generator = state.get_input("generator") + height = state.get_input("height") + width = state.get_input("width") + device = state.get_input("device") + dtype = state.get_input("dtype") + + batch_size = state.get_intermediate("batch_size") + + if device is None: + device = pipeline._execution_device + if dtype is None: + dtype = pipeline.vae.dtype + + image = pipeline.image_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=dtype) + + latents_mean = latents_std = None + if hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None: + latents_mean = torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None: + latents_std = torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if pipeline.vae.config.force_upcast: + image = image.float() + pipeline.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(pipeline.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(pipeline.vae.encode(image), generator=generator) + + if pipeline.vae.config.force_upcast: + pipeline.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * pipeline.vae.config.scaling_factor / latents_std + else: + init_latents = pipeline.vae.config.scaling_factor * init_latents + + state.add_intermediate("image_latents", init_latents) + + return pipeline, state + + class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] @@ -498,9 +594,9 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin denoising_start = state.get_input("denoising_start") batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds", None) + prompt_embeds = state.get_intermediate("prompt_embeds") # image to image only - latent_timestep = state.get_intermediate("latent_timestep", None) + latent_timestep = state.get_intermediate("latent_timestep") if dtype is None and prompt_embeds is not None: dtype = prompt_embeds.dtype @@ -1872,12 +1968,6 @@ def prepare_latents_img2img( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -1891,6 +1981,11 @@ def prepare_latents_img2img( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() From 4fa85c796316b3edee90e24f17163eab41efc1dd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 31 Dec 2024 02:57:42 +0100 Subject: [PATCH 026/170] add model_manager and global offloading method --- src/diffusers/guider.py | 18 +- ...pipeline_controlnet_union_sd_xl_img2img.py | 11 +- src/diffusers/pipelines/model_manager.py | 316 ++++++++++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 2 +- 4 files changed, 337 insertions(+), 10 deletions(-) create mode 100644 src/diffusers/pipelines/model_manager.py diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index e00e8878d021..6be2f47199ab 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -32,9 +32,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 95cf067fce12..38076df9e442 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -877,12 +877,6 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -896,6 +890,11 @@ def prepare_latents( init_latents = image else: + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/model_manager.py b/src/diffusers/pipelines/model_manager.py new file mode 100644 index 000000000000..ea7e552869d4 --- /dev/null +++ b/src/diffusers/pipelines/model_manager.py @@ -0,0 +1,316 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from itertools import combinations +from typing import List, Optional, Union + +import torch + +from ..utils import ( + is_accelerate_available, + logging, +) + + +if is_accelerate_available(): + from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module + from accelerate.state import PartialState + from accelerate.utils import send_to_device + from accelerate.utils.memory import clear_device_cache + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi Notes: copied from modeling_utils.py (decide later where to put this) +def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to + benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch + discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are + tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm + layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + +class CustomOffloadHook(ModelHook): + """ + A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are + on the given device. Optionally offloads other models to the CPU before the forward pass is called. + + Args: + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. + """ + + def __init__( + self, + execution_device: Optional[Union[str, int, torch.device]] = None, + other_hooks: Optional[List["UserCustomOffloadHook"]] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, + ): + self.execution_device = execution_device if execution_device is not None else PartialState().default_device + self.other_hooks = other_hooks + self.offload_strategy = offload_strategy + self.model_id = None + + def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): + self.offload_strategy = offload_strategy + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + """ + Add a hook to the list of hooks to consider for offloading. + """ + if self.other_hooks is None: + self.other_hooks = [] + self.other_hooks.append(hook) + + def init_hook(self, module): + return module.to("cpu") + + def pre_forward(self, module, *args, **kwargs): + if module.device != self.execution_device: + if self.other_hooks is not None: + hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] + # offload all other hooks + import time + + # YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later) + start_time = time.perf_counter() + if self.offload_strategy is not None: + hooks_to_offload = self.offload_strategy( + hooks=hooks_to_offload, + model_id=self.model_id, + model=module, + execution_device=self.execution_device, + ) + end_time = time.perf_counter() + logger.info( + f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" + ) + + for hook in hooks_to_offload: + logger.info( + f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" + ) + hook.offload() + + if hooks_to_offload: + clear_device_cache() + module.to(self.execution_device) + return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) + + +class UserCustomOffloadHook: + """ + A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of + the hook or remove it entirely. + """ + + def __init__(self, model_id, model, hook): + self.model_id = model_id + self.model = model + self.hook = hook + + def offload(self): + self.hook.init_hook(self.model) + + def attach(self): + add_hook_to_module(self.model, self.hook) + self.hook.model_id = self.model_id + + def remove(self): + remove_hook_from_module(self.model) + self.hook.model_id = None + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + self.hook.add_other_hook(hook) + + +def custom_offload_with_hook( + model_id: str, + model: torch.nn.Module, + execution_device: Union[str, int, torch.device] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, +): + hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) + user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) + user_hook.attach() + return user_hook + + +class AutoOffloadStrategy: + """ + Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on + the available memory on the device. + """ + + def __init__(self, size_estimation_margin=0.1): + self.size_estimation_margin = size_estimation_margin + + def __call__(self, hooks, model_id, model, execution_device): + if len(hooks) == 0: + return [] + + current_module_size = get_memory_footprint(model) + current_module_size *= 1 + self.size_estimation_margin + + mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] + if current_module_size < mem_on_device: + return [] + + min_memory_offload = current_module_size - mem_on_device + logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") + + # exlucde models that's not currently loaded on the device + module_sizes = dict( + sorted( + {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), + key=lambda x: x[1], + reverse=True, + ) + ) + + def search_best_candidate(module_sizes, min_memory_offload): + """ + search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a + minimum memory offload size. the combination of models should add up to the smallest modulesize that is + larger than `min_memory_offload` + """ + model_ids = list(module_sizes.keys()) + best_candidate = None + best_size = float("inf") + for r in range(1, len(model_ids) + 1): + for candidate_model_ids in combinations(model_ids, r): + candidate_size = sum( + module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids + ) + if candidate_size < min_memory_offload: + continue + else: + if best_candidate is None or candidate_size < best_size: + best_candidate = candidate_model_ids + best_size = candidate_size + + return best_candidate + + best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) + + if best_offload_model_ids is None: + # if no combination is found, meaning that we cannot meet the memory requirement, offload all models + logger.warning("no combination of models to offload to cpu is found, offloading all models") + hooks_to_offload = hooks + else: + hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] + + return hooks_to_offload + + +class ModelManager: + def __init__(self): + self.models = OrderedDict() + self.model_hooks = None + self._auto_offload_enabled = False + + def add(self, model_id, model): + if model_id not in self.models: + self.models[model_id] = model + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + def remove(self, model_id): + self.models.pop(model_id) + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1): + for model_id, model in self.models.items(): + if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): + remove_hook_from_module(model, recurse=True) + + self.disable_auto_cpu_offload() + offload_strategy = AutoOffloadStrategy(size_estimation_margin=size_estimation_margin) + device = torch.device(device) + if device.index is None: + device = torch.device(f"{device.type}:{0}") + all_hooks = [] + for model_id, model in self.models.items(): + hook = custom_offload_with_hook(model_id, model, device, offload_strategy=offload_strategy) + all_hooks.append(hook) + + for hook in all_hooks: + other_hooks = [h for h in all_hooks if h is not hook] + for other_hook in other_hooks: + if other_hook.hook.execution_device == hook.hook.execution_device: + hook.add_other_hook(other_hook) + + self.model_hooks = all_hooks + self._auto_offload_enabled = True + self._auto_offload_device = device + + def disable_auto_cpu_offload(self): + if self.model_hooks is None: + self._auto_offload_enabled = False + return + + for hook in self.model_hooks: + hook.offload() + hook.remove() + if self.model_hooks: + clear_device_cache() + self.model_hooks = None + self._auto_offload_enabled = False + + def __repr__(self): + col_widths = { + "id": max(15, max(len(id) for id in self.models.keys())), + "class": max(25, max(len(model.__class__.__name__) for model in self.models.values())), + "device": 10, + "dtype": 15, + "size": 10, + } + + # Create the header + sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + + output = "ModelManager:\n" + sep_line + + # Column headers + output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n" + output += dash_line + + # Model entries + for model_id, model in self.models.items(): + device = model.device + dtype = model.dtype + size_bytes = get_memory_footprint(model) + size_gb = size_bytes / (1024**3) + + output += f"{model_id:<{col_widths['id']}} | {model.__class__.__name__:<{col_widths['class']}} | " + output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" + + output += sep_line + return output diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 47229594363a..6e614e9c9522 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -58,7 +58,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. From 72d9a81d996a6306f4f6b96ce35b2b2dcd766633 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 31 Dec 2024 09:54:46 +0100 Subject: [PATCH 027/170] components manager --- ...model_manager.py => components_manager.py} | 111 ++++++++++++------ 1 file changed, 78 insertions(+), 33 deletions(-) rename src/diffusers/pipelines/{model_manager.py => components_manager.py} (73%) diff --git a/src/diffusers/pipelines/model_manager.py b/src/diffusers/pipelines/components_manager.py similarity index 73% rename from src/diffusers/pipelines/model_manager.py rename to src/diffusers/pipelines/components_manager.py index ea7e552869d4..40b4a454200b 100644 --- a/src/diffusers/pipelines/model_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -227,27 +227,37 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload -class ModelManager: +class ComponentsManager: def __init__(self): - self.models = OrderedDict() + self.components = OrderedDict() self.model_hooks = None self._auto_offload_enabled = False - def add(self, model_id, model): - if model_id not in self.models: - self.models[model_id] = model + def add(self, name, component): + if name not in self.components: + self.components[name] = component if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - def remove(self, model_id): - self.models.pop(model_id) + def remove(self, name): + self.components.pop(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) + + def get(self, names: Union[str, List[str]]): + if isinstance(names, str): + if names not in self.components: + raise ValueError(f"Component '{names}' not found in ComponentsManager") + return self.components[names] + elif isinstance(names, list): + return {n: self.components[n] for n in names} + else: + raise ValueError(f"Invalid type for names: {type(names)}") def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1): - for model_id, model in self.models.items(): - if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): - remove_hook_from_module(model, recurse=True) + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): + remove_hook_from_module(component, recurse=True) self.disable_auto_cpu_offload() offload_strategy = AutoOffloadStrategy(size_estimation_margin=size_estimation_margin) @@ -255,9 +265,10 @@ def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1): if device.index is None: device = torch.device(f"{device.type}:{0}") all_hooks = [] - for model_id, model in self.models.items(): - hook = custom_offload_with_hook(model_id, model, device, offload_strategy=offload_strategy) - all_hooks.append(hook) + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module): + hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) + all_hooks.append(hook) for hook in all_hooks: other_hooks = [h for h in all_hooks if h is not hook] @@ -284,33 +295,67 @@ def disable_auto_cpu_offload(self): def __repr__(self): col_widths = { - "id": max(15, max(len(id) for id in self.models.keys())), - "class": max(25, max(len(model.__class__.__name__) for model in self.models.values())), + "id": max(15, max(len(id) for id in self.components.keys())), + "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), "device": 10, "dtype": 15, "size": 10, } - # Create the header + # Create the header lines sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - output = "ModelManager:\n" + sep_line - - # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n" - output += dash_line - - # Model entries - for model_id, model in self.models.items(): - device = model.device - dtype = model.dtype - size_bytes = get_memory_footprint(model) - size_gb = size_bytes / (1024**3) + output = "Components:\n" + sep_line + + # Separate components into models and others + models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} + + # Models section + if models: + output += "Models:\n" + dash_line + # Column headers + output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n" + output += dash_line + + # Model entries + for name, component in models.items(): + device = component.device + dtype = component.dtype + size_bytes = get_memory_footprint(component) + size_gb = size_bytes / (1024**3) + + output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | " + output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" + output += dash_line + + # Other components section + if others: + if models: # Add extra newline if we had models section + output += "\n" + output += "Other Components:\n" + dash_line + # Column headers for other components + output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += dash_line + + # Other component entries + for name, component in others.items(): + output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + output += dash_line - output += f"{model_id:<{col_widths['id']}} | {model.__class__.__name__:<{col_widths['class']}} | " - output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" - - output += sep_line return output + + def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs): + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + if name not in self.components and component is not None: + self.add(name, component) + elif name in self.components: + logger.warning( + f"Component '{name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{name}')\n" + f"2. Use a different name: add('{name}_2', component)" + ) From 10d4a775f1c9490f96ef6d13f719faa5bd7a0461 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 31 Dec 2024 09:55:50 +0100 Subject: [PATCH 028/170] style --- src/diffusers/pipelines/components_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 40b4a454200b..95fd2ecd7de1 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -243,7 +243,7 @@ def remove(self, name): self.components.pop(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - + def get(self, names: Union[str, List[str]]): if isinstance(names, str): if names not in self.components: @@ -328,7 +328,9 @@ def __repr__(self): size_gb = size_bytes / (1024**3) output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | " - output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" + output += ( + f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" + ) output += dash_line # Other components section @@ -349,6 +351,7 @@ def __repr__(self): def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs): from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): if name not in self.components and component is not None: From 27dde51de8526b228ef18eea17143b6bd906117f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 31 Dec 2024 18:06:44 +0100 Subject: [PATCH 029/170] add output arg to run_blocks --- .../pipelines/modular_pipeline_builder.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index faedada96073..ce4f667b4327 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -69,7 +69,12 @@ def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) def get_output(self, key: str, default: Any = None) -> Any: - return self.outputs.get(key, default) + if key in self.outputs: + return self.outputs[key] + elif key in self.intermediates: + return self.intermediates[key] + else: + return default def to_dict(self) -> Dict[str, Any]: return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} @@ -1132,7 +1137,7 @@ def replace_blocks(self, pipeline_blocks, at: int): # Remove the old blocks self.remove_blocks(indices_to_remove) - def run_blocks(self, state: PipelineState = None, **kwargs): + def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): """ Run one or more blocks in sequence, optionally you can pass a previous pipeline state. """ @@ -1174,7 +1179,18 @@ def run_blocks(self, state: PipelineState = None, **kwargs): raise self.maybe_free_model_hooks() - return state + if output is None: + return state + + if isinstance(output, str): + return state.get_output(output) + elif isinstance(output, (list, tuple)): + outputs = {} + for output_name in output: + outputs[output_name] = state.get_output(output_name) + return outputs + else: + raise ValueError(f"Output '{output}' is not a valid output type") def run_pipeline(self, **kwargs): state = PipelineState() From 8c02572e167c509dd8b1f3ebf5ed259a7007272f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 31 Dec 2024 20:08:53 +0100 Subject: [PATCH 030/170] add memory_reserve_margin arg to auto offload --- src/diffusers/pipelines/components_manager.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 95fd2ecd7de1..d6a4b5958750 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -29,6 +29,7 @@ from accelerate.state import PartialState from accelerate.utils import send_to_device from accelerate.utils.memory import clear_device_cache + from accelerate.utils.modeling import convert_file_size_to_int logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -166,17 +167,17 @@ class AutoOffloadStrategy: the available memory on the device. """ - def __init__(self, size_estimation_margin=0.1): - self.size_estimation_margin = size_estimation_margin + def __init__(self, memory_reserve_margin="3GB"): + self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) def __call__(self, hooks, model_id, model, execution_device): if len(hooks) == 0: return [] current_module_size = get_memory_footprint(model) - current_module_size *= 1 + self.size_estimation_margin mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] + mem_on_device = mem_on_device - self.memory_reserve_margin if current_module_size < mem_on_device: return [] @@ -254,13 +255,13 @@ def get(self, names: Union[str, List[str]]): else: raise ValueError(f"Invalid type for names: {type(names)}") - def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1): + def enable_auto_cpu_offload(self, device, memory_reserve_margin="3GB"): for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) self.disable_auto_cpu_offload() - offload_strategy = AutoOffloadStrategy(size_estimation_margin=size_estimation_margin) + offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) device = torch.device(device) if device.index is None: device = torch.device(f"{device.type}:{0}") From a09ca7f27ec77b1791ffeda9ee0c53b29939f55d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 1 Jan 2025 21:43:20 +0100 Subject: [PATCH 031/170] refactors: block __init__ no longer accept args. remove update_states from pipeline blocks, add update_states to modularpipeline, remove multi-block support for modular pipeline, remove offload support on modular pipeline --- src/diffusers/guider.py | 10 +- .../pipelines/modular_pipeline_builder.py | 791 +++--------------- .../pipeline_stable_diffusion_xl_modular.py | 123 +-- 3 files changed, 158 insertions(+), 766 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index 6be2f47199ab..e5942ff560b1 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -356,7 +356,7 @@ def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): self._guidance_rescale = guidance_rescale self._batch_size = batch_size if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - self.original_attn_proc = pipeline.unet.attn_processors + pipeline.original_attn_proc = pipeline.unet.attn_processors self._set_pag_attn_processor( model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, pag_applied_layers=self.pag_applied_layers, @@ -366,11 +366,11 @@ def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): def reset_guider(self, pipeline): if ( self.do_perturbed_attention_guidance - and hasattr(self, "original_attn_proc") - and self.original_attn_proc is not None + and hasattr(pipeline, "original_attn_proc") + and pipeline.original_attn_proc is not None ): - pipeline.unet.set_attn_processor(self.original_attn_proc) - self.original_attn_proc = None + pipeline.unet.set_attn_processor(pipeline.original_attn_proc) + pipeline.original_attn_proc = None def maybe_update_guider(self, pipeline, timestep): pass diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline_builder.py index ce4f667b4327..d91dbe6f1f81 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline_builder.py @@ -16,7 +16,7 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import torch from tqdm.auto import tqdm @@ -27,9 +27,6 @@ is_accelerate_version, logging, ) -from ..utils.hub_utils import validate_hf_hub_args -from .pipeline_loading_utils import _fetch_class_library_tuple -from .pipeline_utils import DiffusionPipeline if is_accelerate_available(): @@ -102,10 +99,10 @@ def format_value(v): class PipelineBlock: + # YiYi Notes: do we need this? + # pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list expected_components = [] - expected_auxiliaries = [] expected_configs = [] - _model_cpu_offload_seq = None @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -120,110 +117,11 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [] - @property - def model_cpu_offload_seq(self): - """ - adjust the model_cpu_offload_seq to reflect actual components loaded in the block - """ - - model_cpu_offload_seq = [] - block_component_names = [k for k, v in self.components.items() if isinstance(v, torch.nn.Module)] - if len(block_component_names) == 0: - return None - if len(block_component_names) == 1: - return block_component_names[0] - else: - if self._model_cpu_offload_seq is None: - raise ValueError( - f"Block {self.__class__.__name__} has multiple components but no model_cpu_offload_seq specified" - ) - model_cpu_offload_seq = [m for m in self._model_cpu_offload_seq.split("->") if m in block_component_names] - remaining = [m for m in block_component_names if m not in model_cpu_offload_seq] - if remaining: - logger.warning( - f"Block {self.__class__.__name__} has components {remaining} that are not in model_cpu_offload_seq {self._model_cpu_offload_seq}" - ) - return "->".join(model_cpu_offload_seq) - - def update_states(self, **kwargs): - """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None states are being - overwritten. - - Args: - **kwargs: Keyword arguments containing components, or configs to add/update. - e.g. pipeline_block.update_states(unet=unet1, vae=None) - """ - # Add expected components - for component_name in self.expected_components: - if component_name in kwargs: - if component_name in self.components and self.components[component_name] is not None: - if id(self.components[component_name]) != id(kwargs[component_name]): - logger.info( - f"Overwriting existing component '{component_name}' " - f"(type: {type(self.components[component_name]).__name__}) " - f"with new value (type: {type(kwargs[component_name]).__name__})" - ) - self.components[component_name] = kwargs.pop(component_name) - - # Add expected configs - for config_name in self.expected_configs: - if config_name in kwargs: - if config_name in self.configs and self.configs[config_name] is not None: - if self.configs[config_name] != kwargs[config_name]: - logger.info( - f"Overwriting existing config '{config_name}' " - f"(value: {self.configs[config_name]}) " - f"with new value ({kwargs[config_name]})" - ) - self.configs[config_name] = kwargs.pop(config_name) - - def __init__(self, **kwargs): + def __init__(self): self.components: Dict[str, Any] = {} self.auxiliaries: Dict[str, Any] = {} self.configs: Dict[str, Any] = {} - self.update_states(**kwargs) - - # YiYi notes, does pipeline block need "states"? it is not going to be used on its own - # TODO: address existing components -> overwrite or not? currently overwrite - def add_states_from_pipe(self, pipe: DiffusionPipeline, **kwargs): - """ - add components/auxiliaries/configs from a diffusion pipeline object. - - Args: - pipe: A `[DiffusionPipeline]` object. - **kwargs: Additional states to update, these take precedence over pipe values. - - Returns: - PipelineBlock: An instance loaded with the pipeline's components and configurations. - """ - states_to_update = {} - - # Get components - prefer kwargs over pipe values - for component_name in self.expected_components: - if component_name in kwargs: - states_to_update[component_name] = kwargs.pop(component_name) - elif component_name in pipe.components: - states_to_update[component_name] = pipe.components[component_name] - - # Get configs - prefer kwargs over pipe values - pipe_config = dict(pipe.config) - for config_name in self.expected_configs: - if config_name in kwargs: - states_to_update[config_name] = kwargs.pop(config_name) - elif config_name in pipe_config: - states_to_update[config_name] = pipe_config[config_name] - - # Update all states at once - self.update_states(**states_to_update) - - @validate_hf_hub_args - def add_states_from_pretrained(self, pretrained_model_or_path, **kwargs): - base_pipeline = DiffusionPipeline.from_pretrained(pretrained_model_or_path, **kwargs) - self.add_states_from_pipe(base_pipeline, **kwargs) - def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -238,14 +136,6 @@ def __repr__(self): f"{k}={type(self.components[k]).__name__}" if k in loaded_components else f"{k}" for k in all_components ) - # Auxiliaries section - expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) - loaded_auxiliaries = set(self.auxiliaries.keys()) - all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) - auxiliaries = ", ".join( - f"{k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f"{k}" for k in all_auxiliaries - ) - # Configs section expected_configs = set(getattr(self, "expected_configs", [])) loaded_configs = set(self.configs.keys()) @@ -263,7 +153,6 @@ def __repr__(self): return ( f"{class_name}(\n" f" components: {components}\n" - f" auxiliaries: {auxiliaries}\n" f" configs: {configs}\n" f" blocks: {blocks}\n" f" inputs: {inputs}\n" @@ -303,7 +192,6 @@ class MultiPipelineBlocks: block_classes = [] block_prefixes = [] - _model_cpu_offload_seq = None @property def expected_components(self): @@ -314,15 +202,6 @@ def expected_components(self): expected_components.append(component) return expected_components - @property - def expected_auxiliaries(self): - expected_auxiliaries = [] - for block in self.blocks.values(): - for auxiliary in block.expected_auxiliaries: - if auxiliary not in expected_auxiliaries: - expected_auxiliaries.append(auxiliary) - return expected_auxiliaries - @property def expected_configs(self): expected_configs = [] @@ -332,11 +211,11 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - def __init__(self, **kwargs): + def __init__(self): blocks = OrderedDict() for block_prefix, block_cls in zip(self.block_prefixes, self.block_classes): block_name = f"{block_prefix}_step" if block_prefix != "" else "step" - blocks[block_name] = block_cls(**kwargs) + blocks[block_name] = block_cls() self.blocks = blocks # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc @@ -345,7 +224,12 @@ def components(self): # Combine components from all blocks components = {} for block_name, block in self.blocks.items(): - components.update(block.components) + for key, value in block.components.items(): + # Only update if: + # 1. Key doesn't exist yet in components, OR + # 2. New value is not None + if key not in components or value is not None: + components[key] = value return components @property @@ -383,95 +267,6 @@ def model_cpu_offload_seq(self): def __call__(self, pipeline, state): raise NotImplementedError("__call__ method must be implemented in subclasses") - def update_states(self, **kwargs): - """ - Update states for each block with support for block-specific kwargs. - - Args: - **kwargs: Can include both general kwargs (e.g., 'unet') and - block-specific kwargs (e.g., 'img2img_step_unet') - - Example: - pipeline.update_states( - img2img_step_unet=unet2, # Only for img2img_step step_unet=unet1, # Only for step vae=vae1 # For any - block that expects vae - ) - """ - for block_name, block in self.blocks.items(): - # Prepare block-specific kwargs - if isinstance(block, PipelineBlock): - block_kwargs = {} - - # Check for block-specific kwargs first (e.g., 'img2img_unet') - prefix = f"{block_name.replace('_step', '')}_" - for key, value in kwargs.items(): - if key.startswith(prefix): - # Remove prefix and add to block kwargs - block_kwargs[key[len(prefix) :]] = value - - # For any expected component/auxiliary/config not found with prefix, - # fall back to general kwargs - for name in ( - block.expected_components - + - # block.expected_auxiliaries + - block.expected_configs - ): - if name not in block_kwargs: - if name in kwargs: - block_kwargs[name] = kwargs[name] - elif isinstance(block, MultiPipelineBlocks): - block_kwargs = kwargs - else: - raise ValueError(f"Unsupported block type: {type(block).__name__}") - - # Update the block with its specific kwargs - block.update_states(**block_kwargs) - - def add_states_from_pipe(self, pipe: DiffusionPipeline, **kwargs): - """ - Load components from pipe with support for block-specific kwargs. - - Args: - pipe: DiffusionPipeline object - **kwargs: Can include both general kwargs (e.g., 'unet') and - block-specific kwargs (e.g., 'img2img_unet' for 'img2img_step') - """ - for block_name, block in self.blocks.items(): - # Handle different block types - if isinstance(block, PipelineBlock): - block_kwargs = {} - - # Check for block-specific kwargs first (e.g., 'img2img_unet') - prefix = f"{block_name.replace('_step', '')}_" - for key, value in kwargs.items(): - if key.startswith(prefix): - # Remove prefix and add to block kwargs - block_kwargs[key[len(prefix) :]] = value - - # For any expected component/auxiliary/config not found with prefix, - # fall back to general kwargs - for name in ( - block.expected_components - + - # block.expected_auxiliaries + - block.expected_configs - ): - if name not in block_kwargs: - if name in kwargs: - block_kwargs[name] = kwargs[name] - elif isinstance(block, MultiPipelineBlocks): - block_kwargs = kwargs - else: - raise ValueError(f"Unsupported block type: {type(block).__name__}") - - # Load the block with its specific kwargs - block.add_states_from_pipe(pipe, **block_kwargs) - - def add_states_from_pretrained(self, pretrained_model_or_path, **kwargs): - base_pipeline = DiffusionPipeline.from_pretrained(pretrained_model_or_path, **kwargs) - self.add_states_from_pipe(base_pipeline, **kwargs) - def __repr__(self): class_name = self.__class__.__name__ @@ -485,12 +280,8 @@ def __repr__(self): ) # Auxiliaries section - expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) - loaded_auxiliaries = set(self.auxiliaries.keys()) - all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" - for k in all_auxiliaries + f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() ) # Configs section @@ -527,6 +318,8 @@ def __repr__(self): ) +# YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient: +# user will need to explicitly write the dispatch logic in __call__ for each subclass of this class AutoPipelineBlocks(MultiPipelineBlocks): """ A class that automatically selects which block to run based on trigger inputs. @@ -541,8 +334,8 @@ class AutoPipelineBlocks(MultiPipelineBlocks): block_prefixes = [] block_trigger_inputs = [] - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self): + super().__init__() self.__post_init__() def __post_init__(self): @@ -610,12 +403,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise - @property - def model_cpu_offload_seq(self): - default_block = self.trigger_to_block_map.get(None) - - return default_block.model_cpu_offload_seq - def __repr__(self): class_name = self.__class__.__name__ @@ -629,12 +416,8 @@ def __repr__(self): ) # Auxiliaries section - expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) - loaded_auxiliaries = set(self.auxiliaries.keys()) - all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" - for k in all_auxiliaries + f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() ) # Configs section @@ -733,37 +516,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise return pipeline, state - @property - def model_cpu_offload_seq(self): - model_cpu_offload_seq = [] - - for block_name, block in self.blocks.items(): - block_components = [k for k, v in block.components.items() if isinstance(v, torch.nn.Module)] - if len(block_components) == 0: - continue - if len(block_components) == 1: - if block_components[0] in model_cpu_offload_seq: - model_cpu_offload_seq.remove(block_components[0]) - model_cpu_offload_seq.append(block_components[0]) - else: - if block.model_cpu_offload_seq is None: - raise ValueError( - f"Block {block_name}:{block.__class__.__name__} has multiple components {block_components} but no model_cpu_offload_seq specified" - ) - for model_str in block.model_cpu_offload_seq.split("->"): - if model_str in block_components: - # if it is already in the list,remove previous occurence and add to the end - if model_str in model_cpu_offload_seq: - model_cpu_offload_seq.remove(model_str) - model_cpu_offload_seq.append(model_str) - block_components.remove(model_str) - if len(block_components) > 0: - logger.warning( - f"Block {block_name}:{block.__class__.__name__} has components {block_components} that are not in model_cpu_offload_seq {block.model_cpu_offload_seq}" - ) - - return "->".join(model_cpu_offload_seq) - def __repr__(self): class_name = self.__class__.__name__ @@ -777,12 +529,8 @@ def __repr__(self): ) # Auxiliaries section - expected_auxiliaries = set(getattr(self, "expected_auxiliaries", [])) - loaded_auxiliaries = set(self.auxiliaries.keys()) - all_auxiliaries = sorted(expected_auxiliaries | loaded_auxiliaries) auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(self.auxiliaries[k]).__name__}" if k in loaded_auxiliaries else f" - {k}" - for k in all_auxiliaries + f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() ) # Configs section @@ -842,31 +590,21 @@ class ModularPipelineBuilder(ConfigMixin): """ config_name = "model_index.json" - model_cpu_offload_seq = None - hf_device_map = None _exclude_from_cpu_offload = [] - default_pipeline_blocks = [] - def __init__(self): - super().__init__() - self.register_to_config() - self.pipeline_blocks = [] - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.register_modules - def register_modules(self, **kwargs): - for name, module in kwargs.items(): - # retrieve library - if module is None or isinstance(module, (tuple, list)) and module[0] is None: - register_dict = {name: (None, None)} - else: - library, class_name = _fetch_class_library_tuple(module) - register_dict = {name: (library, class_name)} + def __init__(self, block): + self.pipeline_block = block - # save model index config - self.register_to_config(**register_dict) + # add default components from pipeline_block (e.g. guider) + for key, value in block.components.items(): + setattr(self, key, value) - # set models - setattr(self, name, module) + # add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt) + self.register_to_config(**block.configs) + + # add default auxiliaries from pipeline_block (e.g. image_processor) + for key, value in block.auxiliaries.items(): + setattr(self, key, value) @property def device(self) -> torch.device: @@ -920,70 +658,21 @@ def dtype(self) -> torch.dtype: return torch.float32 @property - def components(self) -> Dict[str, Any]: - r""" - The `self.components` property returns all modules needed to initialize the pipeline, as defined by the - pipeline blocks. - - Returns (`dict`): - A dictionary containing all the components defined in the pipeline blocks. - """ + def expected_components(self): + return self.pipeline_block.expected_components - expected_components = set() - for block in self.pipeline_blocks: - expected_components.update(block.components.keys()) + @property + def expected_configs(self): + return self.pipeline_block.expected_configs + @property + def components(self): components = {} - for name in expected_components: + for name in self.expected_components: if hasattr(self, name): components[name] = getattr(self, name) - return components - @property - def auxiliaries(self) -> Dict[str, Any]: - r""" - The `self.auxiliaries` property returns all auxiliaries needed to initialize the pipeline, as defined by the - pipeline blocks. - - Returns (`dict`): - A dictionary containing all the auxiliaries defined in the pipeline blocks. - """ - # First collect all expected auxiliary names from blocks - expected_auxiliaries = set() - for block in self.pipeline_blocks: - expected_auxiliaries.update(block.auxiliaries.keys()) - - # Then fetch the actual auxiliaries from the pipeline - auxiliaries = {} - for name in expected_auxiliaries: - if hasattr(self, name): - auxiliaries[name] = getattr(self, name) - - return auxiliaries - - @property - def configs(self) -> Dict[str, Any]: - r""" - The `self.configs` property returns all configs needed to initialize the pipeline, as defined by the pipeline - blocks. - - Returns (`dict`): - A dictionary containing all the configs defined in the pipeline blocks. - """ - # First collect all expected config names from blocks - expected_configs = set() - for block in self.pipeline_blocks: - expected_configs.update(block.configs.keys()) - - # Then fetch the actual configs from the pipeline's config - configs = {} - for name in expected_configs: - if name in self.config: - configs[name] = self.config[name] - - return configs - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): @@ -1007,136 +696,6 @@ def set_progress_bar_config(self, **kwargs): def __call__(self, *args, **kwargs): raise NotImplementedError("__call__ is not implemented for ModularPipelineBuilder") - # YiYi Notes: do we need to support multiple blocks? - def remove_blocks(self, indices: Union[int, List[int]]): - """ - Remove one or more blocks from the pipeline by their indices and clean up associated components, configs, and - auxiliaries that are no longer needed by remaining blocks. - - Args: - indices (Union[int, List[int]]): The index or list of indices of blocks to remove - """ - # Convert single index to list - indices = [indices] if isinstance(indices, int) else indices - - # Validate indices - for idx in indices: - if not 0 <= idx < len(self.pipeline_blocks): - raise ValueError( - f"Invalid block index {idx}. Index must be between 0 and {len(self.pipeline_blocks) - 1}" - ) - - # Sort indices in descending order to avoid shifting issues when removing - indices = sorted(indices, reverse=True) - - # Store blocks to be removed - blocks_to_remove = [self.pipeline_blocks[idx] for idx in indices] - - # Remove blocks from pipeline - for idx in indices: - self.pipeline_blocks.pop(idx) - - # Consolidate items to remove from all blocks - components_to_remove = {k: v for block in blocks_to_remove for k, v in block.components.items()} - auxiliaries_to_remove = {k: v for block in blocks_to_remove for k, v in block.auxiliaries.items()} - configs_to_remove = {k: v for block in blocks_to_remove for k, v in block.configs.items()} - - # The properties will now reflect only the remaining blocks - remaining_components = self.components - remaining_auxiliaries = self.auxiliaries - remaining_configs = self.configs - - # Clean up all items that are no longer needed - for component_name in components_to_remove: - if component_name not in remaining_components: - if component_name in self.config: - del self.config[component_name] - if hasattr(self, component_name): - delattr(self, component_name) - - for auxiliary_name in auxiliaries_to_remove: - if auxiliary_name not in remaining_auxiliaries: - if hasattr(self, auxiliary_name): - delattr(self, auxiliary_name) - - for config_name in configs_to_remove: - if config_name not in remaining_configs: - if config_name in self.config: - del self.config[config_name] - - # YiYi Notes: I left all the functionalities to support adding multiple blocks - # but I wonder if it is still needed now we have `SequentialBlocks` and user can always combine them into one before adding to the builder - def add_blocks(self, pipeline_blocks, at: int = -1): - """Add blocks to the pipeline. - - Args: - pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances. - at (int, optional): Index at which to insert the blocks. Defaults to -1 (append at end). - """ - # Convert single block to list for uniform processing - if not isinstance(pipeline_blocks, (list, tuple)): - pipeline_blocks = [pipeline_blocks] - - # Validate insert_at index - if at != -1 and not 0 <= at <= len(self.pipeline_blocks): - raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks)}") - - # Consolidate all items from blocks - components_to_add = {} - configs_to_add = {} - auxiliaries_to_add = {} - - # Add blocks in order - for i, block in enumerate(pipeline_blocks): - # Add block to pipeline at specified position - if at == -1: - self.pipeline_blocks.append(block) - else: - self.pipeline_blocks.insert(at + i, block) - - # Collect components that don't already exist - for k, v in block.components.items(): - if not hasattr(self, k) or (getattr(self, k, None) is None and v is not None): - components_to_add[k] = v - - # Collect configs and auxiliaries - configs_to_add.update(block.configs) - auxiliaries_to_add.update(block.auxiliaries) - - # Process all items in batches - if components_to_add: - self.register_modules(**components_to_add) - if configs_to_add: - self.register_to_config(**configs_to_add) - for key, value in auxiliaries_to_add.items(): - setattr(self, key, value) - - def replace_blocks(self, pipeline_blocks, at: int): - """Replace one or more blocks in the pipeline at the specified index. - - Args: - pipeline_blocks: A single PipelineBlock instance or a list of PipelineBlock instances - that will replace existing blocks. - at (int): Index at which to replace the blocks. - """ - # Convert single block to list for uniform processing - if not isinstance(pipeline_blocks, (list, tuple)): - pipeline_blocks = [pipeline_blocks] - - # Validate replace_at index - if not 0 <= at < len(self.pipeline_blocks): - raise ValueError(f"Invalid at index {at}. Index must be between 0 and {len(self.pipeline_blocks) - 1}") - - # Add new blocks first - self.add_blocks(pipeline_blocks, at=at) - - # Calculate indices to remove - # We need to remove the original blocks that are now shifted by the length of pipeline_blocks - indices_to_remove = list(range(at + len(pipeline_blocks), at + len(pipeline_blocks) * 2)) - - # Remove the old blocks - self.remove_blocks(indices_to_remove) - def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): """ Run one or more blocks in sequence, optionally you can pass a previous pipeline state. @@ -1154,14 +713,14 @@ def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] for name, default in default_params.items(): if name in input_params: - if name not in self.pipeline_blocks[0].intermediates_inputs: + if name not in self.pipeline_block.intermediates_inputs: state.add_input(name, input_params.pop(name)) else: state.add_input(name, input_params[name]) elif name not in state.inputs: state.add_input(name, default) - for name in self.pipeline_blocks[0].intermediates_inputs: + for name in self.pipeline_block.intermediates_inputs: if name in input_params: state.add_intermediate(name, input_params.pop(name)) @@ -1170,14 +729,12 @@ def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): - for block in self.pipeline_blocks: - try: - pipeline, state = block(self, state) - except Exception: - error_msg = f"Error in block: ({block.__class__.__name__}):\n" - logger.error(error_msg) - raise - self.maybe_free_model_hooks() + try: + pipeline, state = self.pipeline_block(self, state) + except Exception: + error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" + logger.error(error_msg) + raise if output is None: return state @@ -1192,83 +749,81 @@ def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] else: raise ValueError(f"Output '{output}' is not a valid output type") - def run_pipeline(self, **kwargs): - state = PipelineState() + def update_states(self, **kwargs): + """ + Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for + each pipeline block, does not need to be updated by users. Logs if existing non-None components are being + overwritten. - # Make a copy of the input kwargs - input_params = kwargs.copy() + Args: + kwargs (dict): Keyword arguments to update the states. + """ - default_params = self.default_call_parameters + for component_name in self.expected_components: + if component_name in kwargs: + if hasattr(self, component_name) and getattr(self, component_name) is not None: + current_component = getattr(self, component_name) + new_component = kwargs[component_name] - # Add inputs to state, using defaults if not provided - for name, default in default_params.items(): - if name in input_params: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, default) + if not isinstance(new_component, current_component.__class__): + logger.info( + f"Overwriting existing component '{component_name}' " + f"(type: {current_component.__class__.__name__}) " + f"with type: {new_component.__class__.__name__})" + ) + elif isinstance(current_component, torch.nn.Module): + if id(current_component) != id(new_component): + logger.info( + f"Overwriting existing component '{component_name}' " + f"(type: {type(current_component).__name__}) " + f"with new value (type: {type(new_component).__name__})" + ) - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + setattr(self, component_name, kwargs.pop(component_name)) - # Run the pipeline - with torch.no_grad(): - for block in self.pipeline_blocks: - try: - pipeline, state = block(self, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block.__class__.__name__}):\n" - f"Error details: {str(e)}\n" - f"Stack trace:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - self.maybe_free_model_hooks() - - return state.get_output("images") + configs_to_add = {} + for config_name in self.expected_configs: + if config_name in kwargs: + configs_to_add[config_name] = kwargs.pop(config_name) + self.register_to_config(**configs_to_add) @property def default_call_parameters(self) -> Dict[str, Any]: params = {} - for block in self.pipeline_blocks: - for name, default in block.inputs: - if name not in params: - params[name] = default + for name, default in self.pipeline_block.inputs: + params[name] = default return params def __repr__(self): - output = "CustomPipeline Configuration:\n" + output = "ModularPipeline:\n" output += "==============================\n\n" - # List the blocks used to build the pipeline - output += "Pipeline Blocks:\n" - output += "----------------\n" - for i, block in enumerate(self.pipeline_blocks): - if isinstance(block, MultiPipelineBlocks): - output += f"{i}. {block.__class__.__name__} - (CPU offload seq: {block.model_cpu_offload_seq})\n" - # Add sub-blocks information - for sub_block_name, sub_block in block.blocks.items(): - output += f" • {sub_block_name} ({sub_block.__class__.__name__}) \n" - else: - output += f"{i}. {block.__class__.__name__} - (CPU offload seq: {block.model_cpu_offload_seq})\n" - output += "\n" + output += "Pipeline Block:\n" + output += "--------------\n" + block = self.pipeline_block + if isinstance(block, MultiPipelineBlocks): + output += f"{block.__class__.__name__}\n" + # Add sub-blocks information + for sub_block_name, sub_block in block.blocks.items(): + output += f" • {sub_block_name} ({sub_block.__class__.__name__}) \n" + else: + output += f"{block.__class__.__name__}\n" + output += "\n" - intermediates_str = "" - if hasattr(block, "intermediates_inputs"): - intermediates_str += f"{', '.join(block.intermediates_inputs)}" - - if hasattr(block, "intermediates_outputs"): - if intermediates_str: - intermediates_str += " -> " - else: - intermediates_str += "-> " - intermediates_str += f"{', '.join(block.intermediates_outputs)}" + intermediates_str = "" + if hasattr(block, "intermediates_inputs"): + intermediates_str += f"{', '.join(block.intermediates_inputs)}" + if hasattr(block, "intermediates_outputs"): if intermediates_str: - output += f" {intermediates_str}\n" + intermediates_str += " -> " + else: + intermediates_str += "-> " + intermediates_str += f"{', '.join(block.intermediates_outputs)}" + + if intermediates_str: + output += f" {intermediates_str}\n" - output += "\n" output += "\n" # List the components registered in the pipeline @@ -1281,36 +836,23 @@ def __repr__(self): output += "\n" output += "\n" - # List the auxiliaries registered in the pipeline - output += "Registered Auxiliaries:\n" - output += "----------------------\n" - for name, auxiliary in self.auxiliaries.items(): - output += f"{name}: {type(auxiliary).__name__}\n" - output += "\n" - # List the configs registered in the pipeline output += "Registered Configs:\n" output += "------------------\n" - for name, config in self.configs.items(): + for name, config in self.config.items(): output += f"{name}: {config!r}\n" output += "\n" # List the default call parameters - output += "Default Call Parameters:\n" + output += "Call Parameters:\n" output += "------------------------\n" - params = self.default_call_parameters - for name, default in params.items(): + for name, default in self.default_call_parameters.items(): output += f"{name}: {default!r}\n" - # Add a section for required call parameters: - # intermediate inputs for the first block - output += "\nRequired Call Parameters:\n" + output += "\nRequired intermediate inputs:\n" output += "--------------------------\n" - for name in self.pipeline_blocks[0].intermediates_inputs: + for name in self.pipeline_block.intermediates_inputs: output += f"{name}: \n" - params[name] = "" - - output += "\nNote: These are the default values. Actual values may be different when running the pipeline." return output # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline @@ -1457,120 +999,3 @@ def module_is_offloaded(module): " `torch_dtype=torch.float16` argument, or use another device for inference." ) return self - - def remove_all_hooks(self): - for _, model in self.components.items(): - if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): - accelerate.hooks.remove_hook_from_module(model, recurse=True) - self._all_hooks = [] - - def find_model_sequence(self): - pass - - # YiYi notes: assume there is only one pipeline block now (still debating if we want to support multiple pipeline blocks) - @property - def model_cpu_offload_seq(self): - return self.pipeline_blocks[0].model_cpu_offload_seq - - def enable_model_cpu_offload( - self, - gpu_id: Optional[int] = None, - device: Union[torch.device, str] = "cuda", - model_cpu_offload_seq: Optional[str] = None, - ): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - - Arguments: - gpu_id (`int`, *optional*): - The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. - device (`torch.Device` or `str`, *optional*, defaults to "cuda"): - The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will - default to "cuda". - """ - _exclude_from_cpu_offload = [] # YiYi Notes: this is not used (keep the variable for now) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." - ) - - model_cpu_offload_seq = model_cpu_offload_seq or self.model_cpu_offload_seq - self._model_cpu_offload_seq_used = model_cpu_offload_seq - if model_cpu_offload_seq is None: - raise ValueError( - "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set or passed." - ) - - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - self.remove_all_hooks() - - torch_device = torch.device(device) - device_index = torch_device.index - - if gpu_id is not None and device_index is not None: - raise ValueError( - f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" - f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" - ) - - # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 - self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) - - device_type = torch_device.type - device = torch.device(f"{device_type}:{self._offload_gpu_id}") - self._offload_device = device - - self.to("cpu", silence_dtype_warnings=True) - device_mod = getattr(torch, device.type, None) - if hasattr(device_mod, "empty_cache") and device_mod.is_available(): - device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} - - self._all_hooks = [] - hook = None - for model_str in model_cpu_offload_seq.split("->"): - model = all_model_components.pop(model_str, None) - if not isinstance(model, torch.nn.Module): - continue - - _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) - self._all_hooks.append(hook) - - # CPU offload models that are not in the seq chain unless they are explicitly excluded - # these models will stay on CPU until maybe_free_model_hooks is called - # some models cannot be in the seq chain because they are iteratively called, such as controlnet - for name, model in all_model_components.items(): - if not isinstance(model, torch.nn.Module): - continue - - if name in _exclude_from_cpu_offload: - model.to(device) - else: - _, hook = cpu_offload_with_hook(model, device) - self._all_hooks.append(hook) - - def maybe_free_model_hooks(self): - r""" - Function that offloads all components, removes all model hooks that were added when using - `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function - is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it - functions correctly when applying enable_model_cpu_offload. - """ - if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: - # `enable_model_cpu_offload` has not be called, so silently do nothing - return - - # make sure the model is in the same state as before calling it - self.enable_model_cpu_offload( - device=getattr(self, "_offload_device", "cuda"), - model_cpu_offload_seq=getattr(self, "_model_cpu_offload_seq_used", None), - ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 6e614e9c9522..962cb5caa75f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -17,7 +17,6 @@ import PIL import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...guider import CFGGuider from ...image_processor import VaeImageProcessor @@ -135,9 +134,6 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_outputs(self) -> List[str]: return ["batch_size"] - def __init__(self): - super().__init__() - @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: prompt = state.get_input("prompt") @@ -158,7 +154,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] expected_configs = ["force_zeros_for_empty_prompt"] - _model_cpu_offload_seq = "text_encoder->text_encoder_2" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -186,21 +181,13 @@ def intermediates_outputs(self) -> List[str]: "negative_pooled_prompt_embeds", ] - def __init__( - self, - text_encoder: Optional[CLIPTextModel] = None, - text_encoder_2: Optional[CLIPTextModelWithProjection] = None, - tokenizer: Optional[CLIPTokenizer] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - force_zeros_for_empty_prompt: bool = True, - ): - super().__init__( - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, - ) + def __init__(self): + super().__init__() + self.configs["force_zeros_for_empty_prompt"] = True + self.components["text_encoder"] = None + self.components["text_encoder_2"] = None + self.components["tokenizer"] = None + self.components["tokenizer_2"] = None @staticmethod def check_inputs( @@ -327,7 +314,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLVAEEncoderStep(PipelineBlock): expected_components = ["vae"] - expected_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -348,10 +334,10 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["image_latents"] - def __init__(self, vae=None): - super().__init__(vae=vae) - self.image_processor = VaeImageProcessor() - self.auxiliaries["image_processor"] = self.image_processor + def __init__(self): + super().__init__() + self.components["vae"] = None + self.auxiliaries["image_processor"] = VaeImageProcessor() @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -445,8 +431,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["timesteps", "num_inference_steps", "latent_timestep"] - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) + def __init__(self): + super().__init__() + self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -516,8 +503,9 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_outputs(self) -> List[str]: return ["timesteps", "num_inference_steps"] - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) + def __init__(self): + super().__init__() + self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -552,7 +540,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): expected_components = ["vae", "scheduler"] - expected_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -576,10 +563,11 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, vae=None, scheduler=None): - super().__init__(vae=vae, scheduler=scheduler) - self.image_processor = VaeImageProcessor() - self.auxiliaries["image_processor"] = self.image_processor + def __init__(self): + super().__init__() + self.auxiliaries["image_processor"] = VaeImageProcessor() + self.components["vae"] = None + self.components["scheduler"] = None @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -648,8 +636,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, scheduler=None): - super().__init__(scheduler=scheduler) + def __init__(self): + super().__init__() + self.components["scheduler"] = None @staticmethod def check_inputs(pipeline, height, width): @@ -732,8 +721,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - def __init__(self, requires_aesthetics_score=False): - super().__init__(requires_aesthetics_score=requires_aesthetics_score) + def __init__(self): + super().__init__() + self.configs["requires_aesthetics_score"] = False @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -831,9 +821,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] - def __init__(self): - super().__init__() - @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: original_size = state.get_input("original_size") @@ -937,10 +924,11 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__(self, unet=None, scheduler=None, guider=None): - if guider is None: - guider = CFGGuider() - super().__init__(unet=unet, scheduler=scheduler, guider=guider) + def __init__(self): + super().__init__() + self.components["guider"] = CFGGuider() + self.components["scheduler"] = None + self.components["unet"] = None @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1041,8 +1029,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] - expected_auxiliaries = ["control_image_processor"] - _model_cpu_offload_seq = "unet" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1081,30 +1067,14 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["latents"] - def __init__( - self, - unet=None, - controlnet=None, - scheduler=None, - guider=None, - controlnet_guider=None, - vae_scale_factor=8.0, - ): - if guider is None: - guider = CFGGuider() - if controlnet_guider is None: - controlnet_guider = CFGGuider() - super().__init__( - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - guider=guider, - controlnet_guider=controlnet_guider, - vae_scale_factor=vae_scale_factor, - ) - control_image_processor = VaeImageProcessor( - vae_scale_factor=vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) + def __init__(self): + super().__init__() + self.components["guider"] = CFGGuider() + self.components["controlnet_guider"] = CFGGuider() + self.components["scheduler"] = None + self.components["unet"] = None + self.components["controlnet"] = None + control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) self.auxiliaries["control_image_processor"] = control_image_processor @torch.no_grad() @@ -1330,7 +1300,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): expected_components = ["vae"] - expected_auxiliaries = ["image_processor"] @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1347,9 +1316,10 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return ["images"] - def __init__(self, vae=None, vae_scale_factor=8): - super().__init__(vae=vae, vae_scale_factor=vae_scale_factor) - self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + def __init__(self): + super().__init__() + self.components["vae"] = None + self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1468,9 +1438,6 @@ class StableDiffusionXLModularPipeline( TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, ): - def __init__(self): - super().__init__() - @property def default_sample_size(self): default_sample_size = 128 From ed59f90f1551f91b71461577d962921d2f157024 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 1 Jan 2025 22:15:48 +0100 Subject: [PATCH 032/170] modular pipeline builder -> ModularPipeline --- src/diffusers/__init__.py | 4 ++-- src/diffusers/pipelines/__init__.py | 4 ++-- .../{modular_pipeline_builder.py => modular_pipeline.py} | 6 +++--- src/diffusers/pipelines/pipeline_loading_utils.py | 2 +- .../pipeline_stable_diffusion_xl_modular.py | 6 +++--- src/diffusers/utils/dummy_pt_objects.py | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename src/diffusers/pipelines/{modular_pipeline_builder.py => modular_pipeline.py} (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1aa356184451..d2608b088f9c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -162,7 +162,7 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularPipelineBuilder", + "ModularPipeline", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", @@ -674,7 +674,7 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularPipelineBuilder, + ModularPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 25bd2d8b7d59..c39e0a3c721a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -46,7 +46,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline_builder"] = ["ModularPipelineBuilder"] + _import_structure["modular_pipeline"] = ["ModularPipeline"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -452,7 +452,7 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline_builder import ModularPipelineBuilder + from .modular_pipeline import ModularPipeline from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, diff --git a/src/diffusers/pipelines/modular_pipeline_builder.py b/src/diffusers/pipelines/modular_pipeline.py similarity index 99% rename from src/diffusers/pipelines/modular_pipeline_builder.py rename to src/diffusers/pipelines/modular_pipeline.py index d91dbe6f1f81..7ac316a8dca7 100644 --- a/src/diffusers/pipelines/modular_pipeline_builder.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -187,7 +187,7 @@ def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]] class MultiPipelineBlocks: """ A class that combines multiple pipeline block classes into one. When used, it has same API and properties as - PipelineBlock. And it can be used in ModularPipelineBuilder as a single pipeline block. + PipelineBlock. And it can be used in ModularPipeline as a single pipeline block. """ block_classes = [] @@ -583,7 +583,7 @@ def __repr__(self): ) -class ModularPipelineBuilder(ConfigMixin): +class ModularPipeline(ConfigMixin): """ Base class for all Modular pipelines. @@ -694,7 +694,7 @@ def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs def __call__(self, *args, **kwargs): - raise NotImplementedError("__call__ is not implemented for ModularPipelineBuilder") + raise NotImplementedError("__call__ is not implemented for ModularPipeline") def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): """ diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 1c996dc9e3bf..87c614ff0eaf 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -376,7 +376,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipelineBuilder": + if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 962cb5caa75f..54878cb4ce73 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -32,9 +32,9 @@ ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline_builder import ( +from ..modular_pipeline import ( AutoPipelineBlocks, - ModularPipelineBuilder, + ModularPipeline, PipelineBlock, PipelineState, SequentialPipelineBlocks, @@ -1433,7 +1433,7 @@ class StableDiffusionXLAllSteps(SequentialPipelineBlocks): class StableDiffusionXLModularPipeline( - ModularPipelineBuilder, + ModularPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 05c0eb21ebb6..b41251ce2bd0 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1140,7 +1140,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipelineBuilder(metaclass=DummyObject): +class ModularPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 72c5bf07c8d8a49d3f96e7bdcf5954bbaf1de608 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 2 Jan 2025 00:49:34 +0100 Subject: [PATCH 033/170] add a from_block class method to modular pipeline --- src/diffusers/pipelines/modular_pipeline.py | 21 ++++++++++++++++--- .../pipeline_stable_diffusion_xl_modular.py | 14 +++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 7ac316a8dca7..6e563caea825 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -27,6 +27,7 @@ is_accelerate_version, logging, ) +from .pipeline_loading_utils import _get_pipeline_class if is_accelerate_available(): @@ -35,9 +36,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_PIPELINE_MAPPING = { - "stable-diffusion-xl": "StableDiffusionXLModularPipeline", -} +MODULAR_PIPELINE_MAPPING = OrderedDict( + [ + ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ] +) @dataclass @@ -103,6 +106,7 @@ class PipelineBlock: # pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list expected_components = [] expected_configs = [] + model_name = None @property def inputs(self) -> Tuple[Tuple[str, Any], ...]: @@ -193,6 +197,10 @@ class MultiPipelineBlocks: block_classes = [] block_prefixes = [] + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + @property def expected_components(self): expected_components = [] @@ -606,6 +614,13 @@ def __init__(self, block): for key, value in block.auxiliaries.items(): setattr(self, key, value) + @classmethod + def from_block(cls, block): + modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] + modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) + + return modular_pipeline_class(block) + @property def device(self) -> torch.device: r""" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 54878cb4ce73..c499e5e3b5be 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -123,6 +123,8 @@ def retrieve_latents( class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -154,6 +156,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLTextEncoderStep(PipelineBlock): expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] expected_configs = ["force_zeros_for_empty_prompt"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -314,6 +317,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLVAEEncoderStep(PipelineBlock): expected_components = ["vae"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -409,6 +413,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -488,6 +493,7 @@ def denoising_value_valid(dnv): class StableDiffusionXLSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -540,6 +546,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): expected_components = ["vae", "scheduler"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -615,6 +622,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareLatentsStep(PipelineBlock): expected_components = ["scheduler"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -696,6 +704,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): expected_configs = ["requires_aesthetics_score"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -799,6 +808,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -893,6 +904,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLDenoiseStep(PipelineBlock): expected_components = ["unet", "scheduler", "guider"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1029,6 +1041,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1300,6 +1313,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): expected_components = ["vae"] + model_name = "stable-diffusion-xl" @property def inputs(self) -> List[Tuple[str, Any]]: From 6c93626f6ff196f8718a352501b45013e9b2928f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 2 Jan 2025 00:59:12 +0100 Subject: [PATCH 034/170] remove run_blocks, just use __call__ --- src/diffusers/pipelines/modular_pipeline.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 6e563caea825..2e3e16e00615 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -708,10 +708,7 @@ def progress_bar(self, iterable=None, total=None): def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs - def __call__(self, *args, **kwargs): - raise NotImplementedError("__call__ is not implemented for ModularPipeline") - - def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): """ Run one or more blocks in sequence, optionally you can pass a previous pipeline state. """ From 1d6330629500cd9be1fab2a4aada4575a7b8206d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 3 Jan 2025 06:07:25 +0100 Subject: [PATCH 035/170] make it work with lora --- src/diffusers/loaders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7050968b6de5..bd0cf71e4c88 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -403,7 +403,7 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None and _pipeline.hf_device_map is None: + if _pipeline is not None and hasattr(_pipeline,"hf_device_map") and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: From 2e0f5c86cc7380bd4e939c2e3ebb28b199899f17 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 3 Jan 2025 18:20:39 +0100 Subject: [PATCH 036/170] start to add inpaint --- .../pipeline_stable_diffusion_xl_modular.py | 170 ++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c499e5e3b5be..0cba09491f53 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -544,6 +544,82 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + expected_components = ["vae", "scheduler"] + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("height", None), + ("width", None), + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ("device", None), + ("dtype", None), + ("image", None), + ("denoising_start", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "latent_timestep", "prompt_embeds"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + + def __init__(self): + super().__init__() + self.auxiliaries["image_processor"] = VaeImageProcessor() + self.components["vae"] = None + self.components["scheduler"] = None + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + latents = state.get_input("latents") + num_images_per_prompt = state.get_input("num_images_per_prompt") + generator = state.get_input("generator") + device = state.get_input("device") + dtype = state.get_input("dtype") + + # image to image only + image = state.get_input("image") + denoising_start = state.get_input("denoising_start") + + batch_size = state.get_intermediate("batch_size") + prompt_embeds = state.get_intermediate("prompt_embeds") + # image to image only + latent_timestep = state.get_intermediate("latent_timestep") + + if dtype is None and prompt_embeds is not None: + dtype = prompt_embeds.dtype + elif dtype is None: + dtype = pipeline.vae.dtype + + if device is None: + device = pipeline._execution_device + + image = pipeline.image_processor.preprocess(image) + add_noise = True if denoising_start is None else False + if latents is None: + latents = pipeline.prepare_latents_img2img( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + add_noise, + ) + + state.add_intermediate("latents", latents) + + return pipeline, state + + class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): expected_components = ["vae", "scheduler"] model_name = "stable-diffusion-xl" @@ -2026,6 +2102,100 @@ def prepare_latents_img2img( return latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents_inpaint( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature From c12a05b9c19851912090418b6cf7f5a51358b8b2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 3 Jan 2025 20:57:44 +0100 Subject: [PATCH 037/170] update to to not assume pipeline has hf_device_map --- src/diffusers/pipelines/modular_pipeline.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 6 +- .../pipeline_stable_diffusion_xl_modular.py | 156 +++++++++--------- 3 files changed, 86 insertions(+), 78 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 2e3e16e00615..711b9a91de2d 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -966,7 +966,7 @@ def module_is_offloaded(module): "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." ) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c505c5a262a3..2f7894a0c0d6 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -422,7 +422,7 @@ def module_is_offloaded(module): "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation." ) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." @@ -1030,7 +1030,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." @@ -1138,7 +1138,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") self.remove_all_hooks() - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0cba09491f53..55ded67743aa 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -544,80 +544,88 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = ["vae", "scheduler"] - model_name = "stable-diffusion-xl" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("height", None), - ("width", None), - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), - ("device", None), - ("dtype", None), - ("image", None), - ("denoising_start", None), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "latent_timestep", "prompt_embeds"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] - - def __init__(self): - super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.components["vae"] = None - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - latents = state.get_input("latents") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - device = state.get_input("device") - dtype = state.get_input("dtype") - - # image to image only - image = state.get_input("image") - denoising_start = state.get_input("denoising_start") - - batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds") - # image to image only - latent_timestep = state.get_intermediate("latent_timestep") - - if dtype is None and prompt_embeds is not None: - dtype = prompt_embeds.dtype - elif dtype is None: - dtype = pipeline.vae.dtype - - if device is None: - device = pipeline._execution_device - - image = pipeline.image_processor.preprocess(image) - add_noise = True if denoising_start is None else False - if latents is None: - latents = pipeline.prepare_latents_img2img( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator, - add_noise, - ) - - state.add_intermediate("latents", latents) - - return pipeline, state +# class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): +# expected_components = ["vae", "scheduler"] +# model_name = "stable-diffusion-xl" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# ("height", None), +# ("width", None), +# ("generator", None), +# ("latents", None), +# ("num_images_per_prompt", 1), +# ("device", None), +# ("dtype", None), +# ("image", None), +# ("denoising_start", None), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return ["batch_size", "latent_timestep", "prompt_embeds"] + +# @property +# def intermediates_outputs(self) -> List[str]: +# return ["latents"] + +# def __init__(self): +# super().__init__() +# self.auxiliaries["image_processor"] = VaeImageProcessor() +# self.components["vae"] = None +# self.components["scheduler"] = None + +# @torch.no_grad() +# def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: +# latents = state.get_input("latents") +# num_images_per_prompt = state.get_input("num_images_per_prompt") +# generator = state.get_input("generator") +# device = state.get_input("device") +# dtype = state.get_input("dtype") + +# # image to image only +# image = state.get_input("image") +# denoising_start = state.get_input("denoising_start") + +# # inpaint only +# strength = state.get_input("strength") +# padding_mask_crop = state.get_input("padding_mask_crop") +# mask_image = state.get_input("mask_image") +# masked_image_latents = state.get_input("masked_image_latents") + + + +# batch_size = state.get_intermediate("batch_size") +# prompt_embeds = state.get_intermediate("prompt_embeds") +# # image to image only +# latent_timestep = state.get_intermediate("latent_timestep") + +# if dtype is None and prompt_embeds is not None: +# dtype = prompt_embeds.dtype +# elif dtype is None: +# dtype = pipeline.vae.dtype + +# if device is None: +# device = pipeline._execution_device + +# image = pipeline.image_processor.preprocess(image) +# add_noise = True if denoising_start is None else False +# if latents is None: +# latents = pipeline.prepare_latents_img2img( +# image, +# latent_timestep, +# batch_size, +# num_images_per_prompt, +# dtype, +# device, +# generator, +# add_noise, +# ) + +# state.add_intermediate("latents", latents) + +# return pipeline, state class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): From 54f410db6cfcf6a45b7a08793a986d02ddb69de2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 6 Jan 2025 09:19:59 +0100 Subject: [PATCH 038/170] add inpaint --- src/diffusers/pipelines/modular_pipeline.py | 101 +-- .../pipeline_stable_diffusion_xl_modular.py | 594 +++++++++++++----- 2 files changed, 456 insertions(+), 239 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 711b9a91de2d..7964bda2ceda 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -166,25 +166,34 @@ def __repr__(self): ) -def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: +def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]]) -> List[Tuple[str, Any]]: """ - Combines multiple lists of (name, default_value) tuples. For duplicate inputs, updates only if current value is - None and new value is not None. Warns if multiple non-None default values exist for the same input. + Combines multiple lists of (name, default_value) tuples from different blocks. For duplicate inputs, updates only if + current value is None and new value is not None. Warns if multiple non-None default values exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_list) pairs """ combined_dict = {} - for inputs in input_lists: + # Track which block provided which value + value_sources = {} + + for block_name, inputs in named_input_lists: for name, value in inputs: if name in combined_dict: current_value = combined_dict[name] if current_value is not None and value is not None and current_value != value: warnings.warn( f"Multiple different default values found for input '{name}': " - f"{current_value} and {value}. Using {current_value}." + f"{current_value} (from block '{value_sources[name]}') and " + f"{value} (from block '{block_name}'). Using {current_value}." ) if current_value is None and value is not None: combined_dict[name] = value + value_sources[name] = block_name else: combined_dict[name] = value + value_sources[name] = block_name return list(combined_dict.items()) @@ -268,62 +277,10 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: raise NotImplementedError("intermediates_outputs property must be implemented in subclasses") - @property - def model_cpu_offload_seq(self): - raise NotImplementedError("model_cpu_offload_seq property must be implemented in subclasses") - def __call__(self, pipeline, state): raise NotImplementedError("__call__ method must be implemented in subclasses") - def __repr__(self): - class_name = self.__class__.__name__ - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) - - # Blocks section - blocks_str = " Blocks:\n" + "\n".join( - f" - {name}={block.__class__.__name__}" for name, block in self.blocks.items() - ) - - # Other information - inputs_str = " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs) - - intermediates_str = ( - " Intermediates:\n" - f" - inputs: {', '.join(self.intermediates_inputs)}\n" - f" - outputs: {', '.join(self.intermediates_outputs)}" - ) - return ( - f"{class_name}(\n" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" - f")" - ) # YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient: @@ -364,7 +321,8 @@ def __post_init__(self): @property def inputs(self) -> List[Tuple[str, Any]]: - return combine_inputs(*(block.inputs for block in self.blocks.values())) + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + return combine_inputs(*named_inputs) @property def intermediates_inputs(self) -> List[str]: @@ -489,7 +447,8 @@ class SequentialPipelineBlocks(MultiPipelineBlocks): @property def inputs(self) -> List[Tuple[str, Any]]: - return combine_inputs(*(block.inputs for block in self.blocks.values())) + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + return combine_inputs(*named_inputs) @property def intermediates_inputs(self) -> List[str]: @@ -822,21 +781,19 @@ def __repr__(self): output += f"{block.__class__.__name__}\n" output += "\n" - intermediates_str = "" - if hasattr(block, "intermediates_inputs"): - intermediates_str += f"{', '.join(block.intermediates_inputs)}" - if hasattr(block, "intermediates_outputs"): - if intermediates_str: - intermediates_str += " -> " - else: - intermediates_str += "-> " - intermediates_str += f"{', '.join(block.intermediates_outputs)}" - - if intermediates_str: + intermediates_str = f"-> {', '.join(block.intermediates_outputs)}" output += f" {intermediates_str}\n" + output += "\n" - output += "\n" + # Add final intermediate outputs for SequentialPipelineBlocks + if isinstance(block, SequentialPipelineBlocks): + last_block = list(block.blocks.values())[-1] + if hasattr(last_block, "intermediates_outputs"): + final_outputs = last_block.intermediates_outputs + final_intermediates_str = f" (final intermediate outputs: {', '.join(final_outputs)})" + output += f" {final_intermediates_str}\n" + output += "\n" # List the components registered in the pipeline output += "Registered Components:\n" @@ -861,7 +818,7 @@ def __repr__(self): for name, default in self.default_call_parameters.items(): output += f"{name}: {default!r}\n" - output += "\nRequired intermediate inputs:\n" + output += "\nIntermediate inputs:\n" output += "--------------------------\n" for name in self.pipeline_block.intermediates_inputs: output += f"{name}: \n" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 55ded67743aa..1798a11fb0f4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -122,6 +122,65 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +class StableDiffusionXLOutputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [("return_dict", True)] + + @property + def intermediates_outputs(self) -> List[str]: + return ["images"] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + images = state.get_intermediate("images") + return_dict = state.get_input("return_dict") + + if not return_dict: + output = (images,) + else: + output = StableDiffusionXLPipelineOutput(images=images) + state.add_output("images", output) + return pipeline, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("image", None), + ("mask_image", None), + ("padding_mask_crop", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["crops_coords", "images"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["images"] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + original_image = state.get_input("image") + padding_mask_crop = state.get_input("padding_mask_crop") + mask_image = state.get_input("mask_image") + images = state.get_intermediate("images") + crops_coords = state.get_intermediate("crops_coords") + + if padding_mask_crop is not None and crops_coords is not None: + images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images] + + state.add_intermediate("images", images) + + return pipeline, state + + class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -182,6 +241,7 @@ def intermediates_outputs(self) -> List[str]: "negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds", + "dtype", ] def __init__(self): @@ -312,6 +372,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds) state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds) state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds) + state.add_intermediate("dtype", prompt_embeds.dtype) return pipeline, state @@ -326,13 +387,12 @@ def inputs(self) -> List[Tuple[str, Any]]: ("generator", None), ("height", None), ("width", None), - ("device", None), - ("dtype", None), + ("num_images_per_prompt", 1), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size"] + return ["batch_size", "dtype","preprocess_kwargs"] @property def intermediates_outputs(self) -> List[str]: @@ -345,34 +405,31 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - image = state.get_input("image") + num_images_per_prompt = state.get_input("num_images_per_prompt") generator = state.get_input("generator") + height = state.get_input("height") width = state.get_input("width") - device = state.get_input("device") - dtype = state.get_input("dtype") + image = state.get_input("image") + preprocess_kwargs = state.get_intermediate("preprocess_kwargs") or {} batch_size = state.get_intermediate("batch_size") + dtype = state.get_intermediate("dtype") - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device if dtype is None: dtype = pipeline.vae.dtype + - image = pipeline.image_processor.preprocess(image, height=height, width=width) + image = pipeline.image_processor.preprocess(image, height=height, width=width, **preprocess_kwargs) image = image.to(device=device, dtype=dtype) - latents_mean = latents_std = None - if hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None: - latents_mean = torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None: - latents_std = torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1) - - # make sure the VAE is in float32 mode, as it overflows in float16 - if pipeline.vae.config.force_upcast: - image = image.float() - pipeline.vae.to(dtype=torch.float32) + if batch_size is None: + batch_size = image.shape[0] + + batch_size = batch_size * num_images_per_prompt + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -387,26 +444,20 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) - init_latents = [ - retrieve_latents(pipeline.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(pipeline.vae.encode(image), generator=generator) - - if pipeline.vae.config.force_upcast: - pipeline.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * pipeline.vae.config.scaling_factor / latents_std + image_latents = pipeline._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) else: - init_latents = pipeline.vae.config.scaling_factor * init_latents - - state.add_intermediate("image_latents", init_latents) + image_latents = torch.cat([image_latents], dim=0) + + state.add_intermediate("image_latents", image_latents) return pipeline, state @@ -425,7 +476,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("strength", 0.3), ("denoising_start", None), ("num_images_per_prompt", 1), - ("device", None), ] @property @@ -446,7 +496,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timesteps = state.get_input("timesteps") sigmas = state.get_input("sigmas") denoising_end = state.get_input("denoising_end") - device = state.get_input("device") # image to image only strength = state.get_input("strength") @@ -456,8 +505,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # image to image only batch_size = state.get_intermediate("batch_size") - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device timesteps, num_inference_steps = retrieve_timesteps( pipeline.scheduler, num_inference_steps, device, timesteps, sigmas @@ -502,7 +550,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("timesteps", None), ("sigmas", None), ("denoising_end", None), - ("device", None), ] @property @@ -519,10 +566,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timesteps = state.get_input("timesteps") sigmas = state.get_input("sigmas") denoising_end = state.get_input("denoising_end") - device = state.get_input("device") - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device timesteps, num_inference_steps = retrieve_timesteps( pipeline.scheduler, num_inference_steps, device, timesteps, sigmas @@ -544,88 +589,212 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -# class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): -# expected_components = ["vae", "scheduler"] -# model_name = "stable-diffusion-xl" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# ("height", None), -# ("width", None), -# ("generator", None), -# ("latents", None), -# ("num_images_per_prompt", 1), -# ("device", None), -# ("dtype", None), -# ("image", None), -# ("denoising_start", None), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return ["batch_size", "latent_timestep", "prompt_embeds"] - -# @property -# def intermediates_outputs(self) -> List[str]: -# return ["latents"] - -# def __init__(self): -# super().__init__() -# self.auxiliaries["image_processor"] = VaeImageProcessor() -# self.components["vae"] = None -# self.components["scheduler"] = None - -# @torch.no_grad() -# def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: -# latents = state.get_input("latents") -# num_images_per_prompt = state.get_input("num_images_per_prompt") -# generator = state.get_input("generator") -# device = state.get_input("device") -# dtype = state.get_input("dtype") - -# # image to image only -# image = state.get_input("image") -# denoising_start = state.get_input("denoising_start") - -# # inpaint only -# strength = state.get_input("strength") -# padding_mask_crop = state.get_input("padding_mask_crop") -# mask_image = state.get_input("mask_image") -# masked_image_latents = state.get_input("masked_image_latents") - - - -# batch_size = state.get_intermediate("batch_size") -# prompt_embeds = state.get_intermediate("prompt_embeds") -# # image to image only -# latent_timestep = state.get_intermediate("latent_timestep") - -# if dtype is None and prompt_embeds is not None: -# dtype = prompt_embeds.dtype -# elif dtype is None: -# dtype = pipeline.vae.dtype - -# if device is None: -# device = pipeline._execution_device - -# image = pipeline.image_processor.preprocess(image) -# add_noise = True if denoising_start is None else False -# if latents is None: -# latents = pipeline.prepare_latents_img2img( -# image, -# latent_timestep, -# batch_size, -# num_images_per_prompt, -# dtype, -# device, -# generator, -# add_noise, -# ) - -# state.add_intermediate("latents", latents) - -# return pipeline, state +class StableDiffusionXLInpaintVaeEncodeStep(PipelineBlock): + expected_components = ["vae"] + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("height", None), + ("width", None), + ("generator", None), + ("num_images_per_prompt", 1), + ("image", None), + ("mask_image", None), + ("padding_mask_crop", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "dtype"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["image_latents", "mask", "masked_image_latents", "crops_coords"] + + def __init__(self): + super().__init__() + self.auxiliaries["image_processor"] = VaeImageProcessor() + self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) + self.components["vae"] = None + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + + num_images_per_prompt = state.get_input("num_images_per_prompt") + # YiYi TODO: we don't put generator back to state but it actually gets used and updated + # it is ok but think about how we can handle mutable inputs better in PipelineState so user would be aware + generator = state.get_input("generator") + + height = state.get_input("height") + width = state.get_input("width") + # inpaint only + image = state.get_input("image") + padding_mask_crop = state.get_input("padding_mask_crop") + mask_image = state.get_input("mask_image") + + batch_size = state.get_intermediate("batch_size") + dtype = state.get_intermediate("dtype") + + if dtype is None: + dtype = pipeline.vae.dtype + device = pipeline._execution_device + + if padding_mask_crop is not None: + crops_coords = pipeline.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + image = pipeline.image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode) + image = image.to(dtype=torch.float32) + + mask = pipeline.mask_processor.preprocess(mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords) + masked_image = image * (mask < 0.5) + + if batch_size is None: + batch_size = image.shape[0] + + batch_size = batch_size * num_images_per_prompt + image = image.to(device=device, dtype=dtype) + image_latents = pipeline._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + + # 7. Prepare mask latent variables + mask, masked_image_latents = pipeline.prepare_mask_latents( + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + ) + + state.add_intermediate("mask", mask) + state.add_intermediate("masked_image_latents", masked_image_latents) + state.add_intermediate("image_latents", image_latents) + state.add_intermediate("crops_coords", crops_coords) + + + return pipeline, state + + +# inpaint-specific +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + expected_components = ["scheduler"] + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("generator", None), + ("latents", None), + ("num_images_per_prompt", 1), + ("denoising_start", None), + ("strength", 0.9999), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["batch_size", "dtype", "latent_timestep", "image_latents", "mask", "masked_image_latents"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents", "mask", "masked_image_latents", "noise"] + + def __init__(self): + super().__init__() + self.components["scheduler"] = None + + @torch.no_grad() + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + latents = state.get_input("latents") + num_images_per_prompt = state.get_input("num_images_per_prompt") + generator = state.get_input("generator") + # image to image only + denoising_start = state.get_input("denoising_start") + + # inpaint only + strength = state.get_input("strength") + + # image to image only + latent_timestep = state.get_intermediate("latent_timestep") + + # YiYi Notes: mask and masked_image_latents should be intermediate outputs from StableDiffusionXLPrepareMaskedImageLatentsStep + image_latents = state.get_intermediate("image_latents") + mask = state.get_intermediate("mask") + masked_image_latents = state.get_intermediate("masked_image_latents") + + + batch_size = state.get_intermediate("batch_size") + dtype = state.get_intermediate("dtype") + + if dtype is None: + dtype = pipeline.vae.dtype + device = pipeline._execution_device + + is_strength_max = strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(pipeline,"unet") and pipeline.unet is not None: + if pipeline.unet.config.in_channels == 4: + masked_image_latents = None + + add_noise = True if denoising_start is None else False + + height = image_latents.shape[-2] * pipeline.vae_scale_factor + width = image_latents.shape[-1] * pipeline.vae_scale_factor + + latents, noise = pipeline.prepare_latents_inpaint( + batch_size * num_images_per_prompt, + pipeline.num_channels_latents, + height, + width, + dtype, + device, + generator, + latents, + image=image_latents, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = pipeline.prepare_mask_latents( + mask, + masked_image_latents, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ) + + state.add_intermediate("latents", latents) + state.add_intermediate("mask", mask) + state.add_intermediate("masked_image_latents", masked_image_latents) + state.add_intermediate("noise", noise) + + return pipeline, state class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @@ -640,15 +809,13 @@ def inputs(self) -> List[Tuple[str, Any]]: ("generator", None), ("latents", None), ("num_images_per_prompt", 1), - ("device", None), - ("dtype", None), ("image", None), ("denoising_start", None), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "latent_timestep", "prompt_embeds"] + return ["batch_size", "dtype", "latent_timestep"] @property def intermediates_outputs(self) -> List[str]: @@ -665,25 +832,20 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin latents = state.get_input("latents") num_images_per_prompt = state.get_input("num_images_per_prompt") generator = state.get_input("generator") - device = state.get_input("device") - dtype = state.get_input("dtype") # image to image only image = state.get_input("image") denoising_start = state.get_input("denoising_start") batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds") + dtype = state.get_intermediate("dtype") # image to image only latent_timestep = state.get_intermediate("latent_timestep") - if dtype is None and prompt_embeds is not None: - dtype = prompt_embeds.dtype - elif dtype is None: + if dtype is None: dtype = pipeline.vae.dtype - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device image = pipeline.image_processor.preprocess(image) add_noise = True if denoising_start is None else False @@ -716,13 +878,11 @@ def inputs(self) -> List[Tuple[str, Any]]: ("generator", None), ("latents", None), ("num_images_per_prompt", 1), - ("device", None), - ("dtype", None), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "prompt_embeds"] + return ["batch_size", "dtype"] @property def intermediates_outputs(self) -> List[str]: @@ -749,21 +909,17 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin latents = state.get_input("latents") num_images_per_prompt = state.get_input("num_images_per_prompt") generator = state.get_input("generator") - device = state.get_input("device") - dtype = state.get_input("dtype") # text to image only height = state.get_input("height") width = state.get_input("width") batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds") - + dtype = state.get_intermediate("dtype") if dtype is None: - dtype = prompt_embeds.dtype + dtype = pipeline.vae.dtype - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device self.check_inputs(pipeline, height, width) @@ -803,7 +959,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("guidance_scale", 5.0), ("aesthetic_score", 6.0), ("negative_aesthetic_score", 2.0), - ("device", None), ] @property @@ -828,7 +983,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") num_images_per_prompt = state.get_input("num_images_per_prompt") guidance_scale = state.get_input("guidance_scale") - device = state.get_input("device") # image to image only aesthetic_score = state.get_input("aesthetic_score") @@ -838,12 +992,16 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin batch_size = state.get_intermediate("batch_size") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device + + if hasattr(pipeline, "vae") and pipeline.vae is not None: + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + else: + vae_scale_factor = 8 height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor + height = height * vae_scale_factor + width = width * vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) @@ -905,7 +1063,6 @@ def inputs(self) -> List[Tuple[str, Any]]: ("negative_crops_coords_top_left", (0, 0)), ("num_images_per_prompt", 1), ("guidance_scale", 5.0), - ("device", None), ] @property @@ -932,8 +1089,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin batch_size = state.get_intermediate("batch_size") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - if device is None: - device = pipeline._execution_device + device = pipeline._execution_device height, width = latents.shape[-2:] height = height * pipeline.vae_scale_factor @@ -1014,6 +1170,9 @@ def intermediates_inputs(self) -> List[str]: "timestep_cond", "prompt_embeds", "negative_prompt_embeds", + "mask", # inpainting + "masked_image_latents", # inpainting + "noise", # inpainting ] @property @@ -1047,6 +1206,29 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timestep_cond = state.get_intermediate("timestep_cond") latents = state.get_intermediate("latents") + # inpainting + mask = state.get_intermediate("mask") + masked_image_latents = state.get_intermediate("masked_image_latents") + noise = state.get_intermediate("noise") + image_latents = state.get_intermediate("image_latents") + + num_channels_unet = pipeline.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if mask is None or masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = latents.shape[1] + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" + f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False @@ -1076,6 +1258,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: negative_pooled_prompt_embeds, ) + if num_channels_unet == 9: + mask = pipeline.guider.prepare_input(mask, mask) + masked_image_latents = pipeline.guider.prepare_input(masked_image_latents, masked_image_latents) + added_cond_kwargs = { "text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids, @@ -1090,6 +1276,11 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # expand the latents if we are doing classifier free guidance latent_model_input = pipeline.guider.prepare_input(latents, latents) latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + + # inpainting + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + # predict the noise residual noise_pred = pipeline.unet( latent_model_input, @@ -1113,6 +1304,17 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) + + if num_channels_unet == 4 and mask is not None and image_latents is not None: + init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] + init_latents_proper = image_latents + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = pipeline.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() @@ -1470,14 +1672,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: image = pipeline.watermark.apply_watermark(image) image = pipeline.image_processor.postprocess(image, output_type=output_type) - - if not return_dict: - output = (image,) - else: - output = StableDiffusionXLPipelineOutput(images=image) - state.add_intermediate("images", image) - state.add_output("images", output) return pipeline, state @@ -1518,6 +1713,7 @@ class StableDiffusionXLAllSteps(SequentialPipelineBlocks): StableDiffusionXLAutoPrepareAdditionalConditioningStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLDecodeLatentsStep, + StableDiffusionXLOutputStep ] block_prefixes = [ "input", @@ -1527,6 +1723,7 @@ class StableDiffusionXLAllSteps(SequentialPipelineBlocks): "prepare_add_cond", "denoise", "decode_latents", + "output" ] @@ -2025,6 +2222,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + # YiYi TODO: refactor using _encode_vae_image def prepare_latents_img2img( self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): @@ -2178,8 +2376,16 @@ def prepare_latents_inpaint( return outputs - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + dtype = image.dtype if self.vae.config.force_upcast: image = image.float() @@ -2198,10 +2404,64 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): self.vae.to(dtype) image_latents = image_latents.to(dtype) - image_latents = self.vae.config.scaling_factor * image_latents + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = self.vae.config.scaling_factor * image_latents return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs From 6985906a2ee757ab31b8292a6321fe026a349f6d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 7 Jan 2025 01:56:33 +0100 Subject: [PATCH 039/170] controlnet input & remove the MultiPipelineBlocks class --- src/diffusers/pipelines/modular_pipeline.py | 204 +++++++++--------- .../pipeline_stable_diffusion_xl_modular.py | 93 ++++++-- 2 files changed, 182 insertions(+), 115 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 7964bda2ceda..ec9df8193cf0 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -197,14 +197,38 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]]) return list(combined_dict.items()) -class MultiPipelineBlocks: +class AutoPipelineBlocks: """ - A class that combines multiple pipeline block classes into one. When used, it has same API and properties as - PipelineBlock. And it can be used in ModularPipeline as a single pipeline block. + A class that automatically selects a block to run based on the inputs. + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for default """ block_classes = [] - block_prefixes = [] + block_names = [] + block_trigger_inputs = [] + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + default_blocks = [t for t in self.block_trigger_inputs if t is None] + if len(default_blocks) > 1 or ( + len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None + ): + raise ValueError( + f"In {self.__class__.__name__}, exactly one None must be specified as the last element " + "in block_trigger_inputs." + ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) @property def model_name(self): @@ -228,13 +252,6 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - def __init__(self): - blocks = OrderedDict() - for block_prefix, block_cls in zip(self.block_prefixes, self.block_classes): - block_name = f"{block_prefix}_step" if block_prefix != "" else "step" - blocks[block_name] = block_cls() - self.blocks = blocks - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc @property def components(self): @@ -265,60 +282,6 @@ def configs(self): configs.update(block.configs) return configs - @property - def inputs(self) -> List[Tuple[str, Any]]: - raise NotImplementedError("inputs property must be implemented in subclasses") - - @property - def intermediates_inputs(self) -> List[str]: - raise NotImplementedError("intermediates_inputs property must be implemented in subclasses") - - @property - def intermediates_outputs(self) -> List[str]: - raise NotImplementedError("intermediates_outputs property must be implemented in subclasses") - - def __call__(self, pipeline, state): - raise NotImplementedError("__call__ method must be implemented in subclasses") - - - - -# YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient: -# user will need to explicitly write the dispatch logic in __call__ for each subclass of this -class AutoPipelineBlocks(MultiPipelineBlocks): - """ - A class that automatically selects which block to run based on trigger inputs. - - Attributes: - block_classes: List of block classes to be used - block_prefixes: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_prefixes = [] - block_trigger_inputs = [] - - def __init__(self): - super().__init__() - self.__post_init__() - - def __post_init__(self): - """ - Create mapping of trigger inputs directly to block objects. Validates that there is at most one default block - (None trigger). - """ - # Check for at most one default block - default_blocks = [t for t in self.block_trigger_inputs if t is None] - if len(default_blocks) > 1: - raise ValueError( - f"Multiple default blocks specified in {self.__class__.__name__}. " - "Must include at most one None in block_trigger_inputs." - ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] @@ -335,30 +298,15 @@ def intermediates_outputs(self) -> List[str]: @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Find default block first (if any) - default_block = self.trigger_to_block_map.get(None) - - # Check which trigger inputs are present - active_triggers = [ - input_name - for input_name in self.block_trigger_inputs - if input_name is not None and state.get_input(input_name) is not None - ] - - # If multiple triggers are active, raise error - if len(active_triggers) > 1: - trigger_names = [f"'{t}'" for t in active_triggers] - raise ValueError( - f"Multiple trigger inputs found ({', '.join(trigger_names)}). " - f"Only one trigger input can be provided for {self.__class__.__name__}." - ) - # Get the block to run (use default if no triggers active) - block = self.trigger_to_block_map.get(active_triggers[0]) if active_triggers else default_block - if block is None: - logger.warning(f"No valid block found in {self.__class__.__name__}, skipping.") - return pipeline, state + block = self.trigger_to_block_map.get(None) + for input_name in self.block_trigger_inputs: + if input_name is not None and state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") return block(pipeline, state) except Exception as e: error_msg = ( @@ -440,10 +388,70 @@ def __repr__(self): ) -class SequentialPipelineBlocks(MultiPipelineBlocks): +class SequentialPipelineBlocks: """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ + block_classes = [] + block_names = [] + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc + @property + def components(self): + # Combine components from all blocks + components = {} + for block_name, block in self.blocks.items(): + for key, value in block.components.items(): + # Only update if: + # 1. Key doesn't exist yet in components, OR + # 2. New value is not None + if key not in components or value is not None: + components[key] = value + return components + + @property + def auxiliaries(self): + # Combine auxiliaries from all blocks + auxiliaries = {} + for block_name, block in self.blocks.items(): + auxiliaries.update(block.auxiliaries) + return auxiliaries + + @property + def configs(self): + # Combine configs from all blocks + configs = {} + for block_name, block in self.blocks.items(): + configs.update(block.configs) + return configs @property def inputs(self) -> List[Tuple[str, Any]]: @@ -467,7 +475,11 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[str]: return list(set().union(*(block.intermediates_outputs for block in self.blocks.values()))) - + + @property + def final_intermediates_outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -536,7 +548,8 @@ def __repr__(self): intermediates_str = ( "\n Intermediates:\n" f" - inputs: {', '.join(self.intermediates_inputs)}\n" - f" - outputs: {', '.join(self.intermediates_outputs)}" + f" - outputs: {', '.join(self.intermediates_outputs)}\n" + f" - final outputs: {', '.join(self.final_intermediates_outputs)}" ) return ( @@ -772,7 +785,7 @@ def __repr__(self): output += "Pipeline Block:\n" output += "--------------\n" block = self.pipeline_block - if isinstance(block, MultiPipelineBlocks): + if hasattr(block, "blocks"): output += f"{block.__class__.__name__}\n" # Add sub-blocks information for sub_block_name, sub_block in block.blocks.items(): @@ -787,13 +800,10 @@ def __repr__(self): output += "\n" # Add final intermediate outputs for SequentialPipelineBlocks - if isinstance(block, SequentialPipelineBlocks): - last_block = list(block.blocks.values())[-1] - if hasattr(last_block, "intermediates_outputs"): - final_outputs = last_block.intermediates_outputs - final_intermediates_str = f" (final intermediate outputs: {', '.join(final_outputs)})" - output += f" {final_intermediates_str}\n" - output += "\n" + if hasattr(block, "final_intermediate_output"): + final_intermediates_str = f" (final intermediate outputs: {', '.join(block.final_intermediate_output)})" + output += f" {final_intermediates_str}\n" + output += "\n" # List the components registered in the pipeline output += "Registered Components:\n" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 1798a11fb0f4..c1154a77dc9f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1360,6 +1360,10 @@ def intermediates_inputs(self) -> List[str]: "pooled_prompt_embeds", "negative_pooled_prompt_embeds", "timestep_cond", + "mask", + "masked_image_latents", + "noise", + "image_latents", ] @property @@ -1406,6 +1410,29 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: timestep_cond = state.get_intermediate("timestep_cond") + # inpainting + mask = state.get_intermediate("mask") + masked_image_latents = state.get_intermediate("masked_image_latents") + noise = state.get_intermediate("noise") + image_latents = state.get_intermediate("image_latents") + num_channels_unet = pipeline.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if mask is None or masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = latents.shape[1] + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" + f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + device = pipeline._execution_device height, width = latents.shape[-2:] @@ -1501,6 +1528,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: pooled_prompt_embeds, negative_pooled_prompt_embeds, ) + if num_channels_unet == 9: + mask = pipeline.guider.prepare_input(mask, mask) + masked_image_latents = pipeline.guider.prepare_input(masked_image_latents, masked_image_latents) added_cond_kwargs = { "text_embeds": pooled_prompt_embeds, @@ -1566,8 +1596,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: mid_block_res_sample, torch.zeros_like(mid_block_res_sample) ) + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + noise_pred = pipeline.unet( - pipeline.scheduler.scale_model_input(latent_model_input, t), + latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, @@ -1587,6 +1621,18 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) + + if num_channels_unet == 4 and mask is not None and image_latents is not None: + init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] + init_latents_proper = image_latents + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = pipeline.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() @@ -1678,31 +1724,44 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLAutoSetTimestepsStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLSetTimestepsStep, StableDiffusionXLImg2ImgSetTimestepsStep] - block_prefixes = ["", "img2img"] - block_trigger_inputs = [None, "image"] + block_classes = [StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLSetTimestepsStep] + block_names = ["img2img", "text2img"] + block_trigger_inputs = ["image", None] class StableDiffusionXLAutoPrepareLatentsStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareLatentsStep] - block_prefixes = ["", "img2img"] - block_trigger_inputs = [None, "image"] + block_classes = [StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLPrepareLatentsStep] + block_names = ["inpaint","img2img", "text2img"] + block_trigger_inputs = ["mask_image", "image", None] class StableDiffusionXLAutoPrepareAdditionalConditioningStep(AutoPipelineBlocks): block_classes = [ - StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareAdditionalConditioningStep, ] - block_prefixes = ["", "img2img"] - block_trigger_inputs = [None, "image"] + block_names = ["img2img", "text2img"] + block_trigger_inputs = ["image", None] class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep] - block_prefixes = ["", "controlnet"] - block_trigger_inputs = [None, "control_image"] + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet", "unet"] + block_trigger_inputs = ["control_image", None] + + +class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] + block_names = ["decode", "output"] + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] class StableDiffusionXLAllSteps(SequentialPipelineBlocks): block_classes = [ @@ -1712,18 +1771,16 @@ class StableDiffusionXLAllSteps(SequentialPipelineBlocks): StableDiffusionXLAutoPrepareLatentsStep, StableDiffusionXLAutoPrepareAdditionalConditioningStep, StableDiffusionXLAutoDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLOutputStep + StableDiffusionXLAutoDecodeStep ] - block_prefixes = [ + block_names = [ "input", "text_encoder", "set_timesteps", "prepare_latents", "prepare_add_cond", "denoise", - "decode_latents", - "output" + "decode" ] From db94ca882d2155387a9bf5fb77d7313b5e9c6a71 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 7 Jan 2025 20:49:58 +0100 Subject: [PATCH 040/170] add controlnet inpaint + more refactor --- src/diffusers/pipelines/modular_pipeline.py | 4 + .../pipeline_stable_diffusion_xl_modular.py | 235 ++++++++++-------- 2 files changed, 141 insertions(+), 98 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index ec9df8193cf0..3fa45d2cbc14 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -305,6 +305,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: block = self.trigger_to_block_map[input_name] break + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + try: logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") return block(pipeline, state) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c1154a77dc9f..775119b564d9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -17,6 +17,7 @@ import PIL import torch +from collections import OrderedDict from ...guider import CFGGuider from ...image_processor import VaeImageProcessor @@ -122,64 +123,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [("return_dict", True)] - - @property - def intermediates_outputs(self) -> List[str]: - return ["images"] - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - images = state.get_intermediate("images") - return_dict = state.get_input("return_dict") - - if not return_dict: - output = (images,) - else: - output = StableDiffusionXLPipelineOutput(images=images) - state.add_output("images", output) - return pipeline, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - ("image", None), - ("mask_image", None), - ("padding_mask_crop", None), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return ["crops_coords", "images"] - - @property - def intermediates_outputs(self) -> List[str]: - return ["images"] - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - original_image = state.get_input("image") - padding_mask_crop = state.get_input("padding_mask_crop") - mask_image = state.get_input("mask_image") - images = state.get_intermediate("images") - crops_coords = state.get_intermediate("crops_coords") - - if padding_mask_crop is not None and crops_coords is not None: - images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images] - - state.add_intermediate("images", images) - - return pipeline, state - class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -376,7 +319,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLVAEEncoderStep(PipelineBlock): +class StableDiffusionXLVaeEncoderStep(PipelineBlock): expected_components = ["vae"] model_name = "stable-diffusion-xl" @@ -589,7 +532,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLInpaintVaeEncodeStep(PipelineBlock): +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): expected_components = ["vae"] model_name = "stable-diffusion-xl" @@ -694,7 +637,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin return pipeline, state -# inpaint-specific class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): expected_components = ["scheduler"] model_name = "stable-diffusion-xl" @@ -804,18 +746,15 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("height", None), - ("width", None), ("generator", None), ("latents", None), ("num_images_per_prompt", 1), - ("image", None), ("denoising_start", None), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype", "latent_timestep"] + return ["batch_size", "dtype", "latent_timestep", "image_latents"] @property def intermediates_outputs(self) -> List[str]: @@ -823,8 +762,6 @@ def intermediates_outputs(self) -> List[str]: def __init__(self): super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.components["vae"] = None self.components["scheduler"] = None @torch.no_grad() @@ -834,24 +771,22 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin generator = state.get_input("generator") # image to image only - image = state.get_input("image") denoising_start = state.get_input("denoising_start") batch_size = state.get_intermediate("batch_size") dtype = state.get_intermediate("dtype") # image to image only latent_timestep = state.get_intermediate("latent_timestep") + image_latents = state.get_intermediate("image_latents") if dtype is None: dtype = pipeline.vae.dtype device = pipeline._execution_device - - image = pipeline.image_processor.preprocess(image) add_noise = True if denoising_start is None else False if latents is None: latents = pipeline.prepare_latents_img2img( - image, + image_latents, latent_timestep, batch_size, num_images_per_prompt, @@ -1723,6 +1658,81 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("image", None), + ("mask_image", None), + ("padding_mask_crop", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return ["crops_coords", "images"] + + @property + def intermediates_outputs(self) -> List[str]: + return ["images"] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + original_image = state.get_input("image") + padding_mask_crop = state.get_input("padding_mask_crop") + mask_image = state.get_input("mask_image") + images = state.get_intermediate("images") + crops_coords = state.get_intermediate("crops_coords") + + if padding_mask_crop is not None and crops_coords is not None: + images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images] + + state.add_intermediate("images", images) + + return pipeline, state + + +class StableDiffusionXLOutputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [("return_dict", True)] + + @property + def intermediates_outputs(self) -> List[str]: + return ["images"] + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + images = state.get_intermediate("images") + return_dict = state.get_input("return_dict") + + if not return_dict: + output = (images,) + else: + output = StableDiffusionXLPipelineOutput(images=images) + state.add_output("images", output) + return pipeline, state + + +class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] + block_names = ["decode", "output"] + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] + + +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + class StableDiffusionXLAutoSetTimestepsStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLSetTimestepsStep] block_names = ["img2img", "text2img"] @@ -1750,38 +1760,67 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): block_trigger_inputs = ["control_image", None] -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] block_names = ["inpaint", "non-inpaint"] block_trigger_inputs = ["padding_mask_crop", None] -class StableDiffusionXLAllSteps(SequentialPipelineBlocks): - block_classes = [ - StableDiffusionXLInputStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoSetTimestepsStep, - StableDiffusionXLAutoPrepareLatentsStep, - StableDiffusionXLAutoPrepareAdditionalConditioningStep, - StableDiffusionXLAutoDenoiseStep, - StableDiffusionXLAutoDecodeStep - ] - block_names = [ - "input", - "text_encoder", - "set_timesteps", - "prepare_latents", - "prepare_add_cond", - "denoise", - "decode" - ] + +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("input", StableDiffusionXLInputStep), + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), + ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("input", StableDiffusionXLInputStep), + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = OrderedDict([ + ("input", StableDiffusionXLInputStep), + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = OrderedDict([ + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +AUTO_BLOCKS = OrderedDict([ + ("input", StableDiffusionXLInputStep), + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), + ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "auto": AUTO_BLOCKS +} class StableDiffusionXLModularPipeline( From e973de64f93759ab9f31b7ed0585c59515f5dab1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 8 Jan 2025 21:47:20 +0100 Subject: [PATCH 041/170] fix contro;net inpaint preprocess --- .../pipeline_stable_diffusion_xl_modular.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 775119b564d9..e4b4295ef7d1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1299,6 +1299,7 @@ def intermediates_inputs(self) -> List[str]: "masked_image_latents", "noise", "image_latents", + "crops_coords", ] @property @@ -1350,6 +1351,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: masked_image_latents = state.get_intermediate("masked_image_latents") noise = state.get_intermediate("noise") image_latents = state.get_intermediate("image_latents") + crops_coords = state.get_intermediate("crops_coords") num_channels_unet = pipeline.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting @@ -1409,6 +1411,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, + crops_coords=crops_coords, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] @@ -1422,6 +1425,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, + crops_coords=crops_coords, ) control_images.append(control_image) @@ -1947,7 +1951,8 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # return image without apply any guidance + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() def prepare_control_image( self, image, @@ -1957,8 +1962,12 @@ def prepare_control_image( num_images_per_prompt, device, dtype, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + crops_coords=None, + ): + if crops_coords is not None: + image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: From 7a34832d5254696a5dc26724288d60dc63f67e9a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 9 Jan 2025 20:29:45 +0000 Subject: [PATCH 042/170] [modular] Stable Diffusion XL ControlNet Union (#10509) StableDiffusionXLControlNetUnionDenoiseStep --- .../pipeline_stable_diffusion_xl_modular.py | 287 ++++++++++++++++++ 1 file changed, 287 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index e4b4295ef7d1..2b0fb9ae6670 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1582,6 +1582,293 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state +class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): + expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + ("control_image", None), + ("control_guidance_start", 0.0), + ("control_guidance_end", 1.0), + ("controlnet_conditioning_scale", 1.0), + ("control_mode", 0), + ("guess_mode", False), + ("num_images_per_prompt", 1), + ("guidance_scale", 5.0), + ("guidance_rescale", 0.0), + ("cross_attention_kwargs", None), + ("generator", None), + ("eta", 0.0), + ("guider_kwargs", None), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + "latents", + "batch_size", + "timesteps", + "num_inference_steps", + "prompt_embeds", + "negative_prompt_embeds", + "add_time_ids", + "negative_add_time_ids", + "pooled_prompt_embeds", + "negative_pooled_prompt_embeds", + "timestep_cond", + "mask", + "noise", + "image_latents", + "crops_coords", + ] + + @property + def intermediates_outputs(self) -> List[str]: + return ["latents"] + + def __init__(self): + super().__init__() + self.components["guider"] = CFGGuider() + self.components["controlnet_guider"] = CFGGuider() + self.components["scheduler"] = None + self.components["unet"] = None + self.components["controlnet"] = None + control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) + self.auxiliaries["control_image_processor"] = control_image_processor + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + guidance_scale = state.get_input("guidance_scale") + guidance_rescale = state.get_input("guidance_rescale") + cross_attention_kwargs = state.get_input("cross_attention_kwargs") + guider_kwargs = state.get_input("guider_kwargs") + generator = state.get_input("generator") + eta = state.get_input("eta") + num_images_per_prompt = state.get_input("num_images_per_prompt") + # controlnet-specific inputs + control_image = state.get_input("control_image") + control_guidance_start = state.get_input("control_guidance_start") + control_guidance_end = state.get_input("control_guidance_end") + controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") + control_mode = state.get_input("control_mode") + guess_mode = state.get_input("guess_mode") + + batch_size = state.get_intermediate("batch_size") + latents = state.get_intermediate("latents") + timesteps = state.get_intermediate("timesteps") + num_inference_steps = state.get_intermediate("num_inference_steps") + + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + add_time_ids = state.get_intermediate("add_time_ids") + negative_add_time_ids = state.get_intermediate("negative_add_time_ids") + + timestep_cond = state.get_intermediate("timestep_cond") + + # inpainting + mask = state.get_intermediate("mask") + noise = state.get_intermediate("noise") + image_latents = state.get_intermediate("image_latents") + crops_coords = state.get_intermediate("crops_coords") + + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + # prepare controlnet inputs + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + + global_pool_conditions = controlnet.config.global_pool_conditions + guess_mode = guess_mode or global_pool_conditions + + num_control_type = controlnet.config.num_control_type + + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + control_type = [0 for _ in range(num_control_type)] + for control_idx in control_mode: + control_type[control_idx] = 1 + + control_type = torch.Tensor(control_type) + + for idx, _ in enumerate(control_image): + control_image[idx] = pipeline.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + ) + height, width = control_image[idx].shape[-2:] + + controlnet_keep = [] + for i in range(len(timesteps)): + controlnet_keep.append( + 1.0 + - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) + ) + + # Prepare conditional inputs for unet using the guider + # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "disable_guidance": disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.guider.set_guider(pipeline, guider_kwargs) + prompt_embeds = pipeline.guider.prepare_input( + prompt_embeds, + negative_prompt_embeds, + ) + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, + ) + pooled_prompt_embeds = pipeline.guider.prepare_input( + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } + + # Prepare conditional inputs for controlnet using the guider + controlnet_disable_guidance = True if disable_guidance or guess_mode else False + controlnet_guider_kwargs = guider_kwargs or {} + controlnet_guider_kwargs = { + **controlnet_guider_kwargs, + "disable_guidance": controlnet_disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) + controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) + controlnet_added_cond_kwargs = { + "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), + "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), + } + for idx, _ in enumerate(control_image): + control_image[idx] = pipeline.controlnet_guider.prepare_input(control_image[idx], control_image[idx]) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + + control_type = ( + control_type.reshape(1, -1) + .to(device, dtype=prompt_embeds.dtype) + .repeat(batch_size * num_images_per_prompt * 2, 1) + ) + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # prepare latents for unet using the guider + latent_model_input = pipeline.guider.prepare_input(latents, latents) + + # prepare latents for controlnet using the guider + control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = pipeline.controlnet( + pipeline.scheduler.scale_model_input(control_model_input, t), + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + control_type=control_type, + control_type_idx=control_mode, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + # when we apply guidance for unet, but not for controlnet: + # add 0 to the unconditional batch + down_block_res_samples = pipeline.guider.prepare_input( + down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] + ) + mid_block_res_sample = pipeline.guider.prepare_input( + mid_block_res_sample, torch.zeros_like(mid_block_res_sample) + ) + + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = pipeline.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + # perform guidance + noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if mask is not None and image_latents is not None: + init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] + init_latents_proper = image_latents + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = pipeline.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + pipeline.guider.reset_guider(pipeline) + pipeline.controlnet_guider.reset_guider(pipeline) + state.add_intermediate("latents", latents) + + return pipeline, state + class StableDiffusionXLDecodeLatentsStep(PipelineBlock): expected_components = ["vae"] model_name = "stable-diffusion-xl" From 2220af6940b6a38effeb4825f1b882bfe66936e7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 11 Jan 2025 09:05:47 +0100 Subject: [PATCH 043/170] refactor --- src/diffusers/pipelines/modular_pipeline.py | 387 +++++++++++++----- .../pipeline_stable_diffusion_xl_modular.py | 34 +- 2 files changed, 314 insertions(+), 107 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 3fa45d2cbc14..30e237d9c675 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -131,37 +131,57 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: def __repr__(self): class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ - # Components section + # Components section - group into main components and auxiliaries if needed expected_components = set(getattr(self, "expected_components", [])) loaded_components = set(self.components.keys()) all_components = sorted(expected_components | loaded_components) - components = ", ".join( - f"{k}={type(self.components[k]).__name__}" if k in loaded_components else f"{k}" for k in all_components - ) + + main_components = [] + auxiliary_components = [] + for k in all_components: + component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + if k in getattr(self, "auxiliary_components", []): + auxiliary_components.append(component_str) + else: + main_components.append(component_str) + + components = "Components:\n" + "\n".join(main_components) + if auxiliary_components: + components += "\n Auxiliaries:\n" + "\n".join(auxiliary_components) # Configs section expected_configs = set(getattr(self, "expected_configs", [])) loaded_configs = set(self.configs.keys()) all_configs = sorted(expected_configs | loaded_configs) - configs = ", ".join(f"{k}={self.configs[k]}" if k in loaded_configs else f"{k}" for k in all_configs) + configs = "Configs:\n" + "\n".join( + f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" + for k in all_configs + ) - # Single block shows itself - blocks = f"step={self.__class__.__name__}" + # Inputs section + inputs = "inputs: " + ", ".join( + f"{name}={default}" if default is not None else name + for name, default in self.inputs + ) - # Other information - inputs = ", ".join(f"{name}={default}" for name, default in self.inputs) - intermediates_inputs = ", ".join(self.intermediates_inputs) - intermediates_outputs = ", ".join(self.intermediates_outputs) + # Intermediates section + input_set = set(self.intermediates_inputs) + output_set = set(self.intermediates_outputs) + + modified_inputs = [f"{item}*" for item in self.intermediates_inputs] + new_outputs = [item for item in self.intermediates_outputs if item not in input_set] + + intermediates = f"intermediates: {', '.join(modified_inputs)} -> {', '.join(new_outputs)}" return ( f"{class_name}(\n" - f" components: {components}\n" - f" configs: {configs}\n" - f" blocks: {blocks}\n" - f" inputs: {inputs}\n" - f" intermediates_inputs: {intermediates_inputs}\n" - f" intermediates_outputs: {intermediates_outputs}\n" + f" Class: {base_class}\n" + f" {components}\n" + f" {configs}\n" + f" {inputs}\n" + f" {intermediates}\n" f")" ) @@ -175,7 +195,6 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]]) named_input_lists: List of tuples containing (block_name, input_list) pairs """ combined_dict = {} - # Track which block provided which value value_sources = {} for block_name, inputs in named_input_lists: @@ -229,6 +248,7 @@ def __init__(self): # Map trigger inputs to block objects self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) @property def model_name(self): @@ -321,73 +341,99 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise - def __repr__(self): - class_name = self.__class__.__name__ - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) - - # Blocks section with trigger information - blocks_str = " Blocks:\n" - for name, block in self.blocks.items(): - # Find trigger for this block - trigger = next((t for t, b in self.trigger_to_block_map.items() if b == block), None) - trigger_str = " (default)" if trigger is None else f" (triggered by: {trigger})" + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + + return trigger_inputs - blocks_str += f" {name} ({block.__class__.__name__}){trigger_str}\n" + @property + def trigger_inputs(self): + return self._get_trigger_inputs() - # Add inputs information + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + all_triggers = set(self.trigger_to_block_map.keys()) + + sections = [] + for trigger in sorted(all_triggers, key=lambda x: str(x)): + sections.append(f"\n Trigger Input: {trigger}\n") + + block = self.trigger_to_block_map.get(trigger) + if block is None: + continue + + expected_components = set(getattr(block, "expected_components", [])) + loaded_components = set(k for k, v in self.components.items() + if v is not None and hasattr(block, k)) + all_components = sorted(expected_components | loaded_components) + if all_components: + sections.append(" Components:\n" + "\n".join( + f" - {k}={type(self.components[k]).__name__}" if k in loaded_components + else f" - {k}" for k in all_components + )) + + if self.auxiliaries: + sections.append(" Auxiliaries:\n" + "\n".join( + f" - {k}={type(v).__name__}" + for k, v in self.auxiliaries.items() + )) + + if self.configs: + sections.append(" Configs:\n" + "\n".join( + f" - {k}={v}" for k, v in self.configs.items() + )) + + sections.append(f" Block: {block.__class__.__name__}") + if hasattr(block, "inputs"): - inputs_str = ", ".join(f"{name}={default}" for name, default in block.inputs) + inputs_str = ", ".join( + name if default is None else f"{name}={default}" + for name, default in block.inputs + ) if inputs_str: - blocks_str += f" inputs: {inputs_str}\n" + sections.append(f" inputs: {inputs_str}") - # Add intermediates information if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): intermediates_str = "" if hasattr(block, "intermediates_inputs"): intermediates_str += f"{', '.join(block.intermediates_inputs)}" - if hasattr(block, "intermediates_outputs"): if intermediates_str: intermediates_str += " -> " intermediates_str += f"{', '.join(block.intermediates_outputs)}" - if intermediates_str: - blocks_str += f" intermediates: {intermediates_str}\n" - blocks_str += "\n" - intermediates_str = ( - "\n Intermediates:\n" - f" - inputs: {', '.join(self.intermediates_inputs)}\n" - f" - outputs: {', '.join(self.intermediates_outputs)}" - ) + sections.append(f" intermediates: {intermediates_str}") + + sections.append("") return ( f"{class_name}(\n" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{intermediates_str}\n" + f" Class: {base_class}\n" + f"{chr(10).join(sections)}" f")" ) @@ -421,6 +467,22 @@ def expected_configs(self): expected_configs.append(config) return expected_configs + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + def __init__(self): blocks = OrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): @@ -498,9 +560,120 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise return pipeline, state + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + return fn_recursive_get_trigger(self.blocks) + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def _traverse_trigger_blocks(self, trigger_inputs): + + def fn_recursive_traverse(block, block_name, trigger_inputs): + result_blocks = OrderedDict() + # sequential or PipelineBlock + if not hasattr(block, 'block_trigger_inputs'): + if hasattr(block, 'blocks'): + # sequential + for block_name, block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs) + result_blocks.update(blocks_to_update) + else: + # PipelineBlock + result_blocks[block_name] = block + return result_blocks + + # auto + else: + # Find first block_trigger_input that matches any value in our trigger_value tuple + this_block = None + for trigger_input in block.block_trigger_inputs: + if trigger_input is not None and trigger_input in trigger_inputs: + this_block = block.trigger_to_block_map[trigger_input] + break + + # If no matches found, try to get the default (None) block + if this_block is None and None in block.block_trigger_inputs: + this_block = block.trigger_to_block_map[None] + + if this_block is not None: + # sequential/auto + if hasattr(this_block, 'blocks'): + result_blocks.update(fn_recursive_traverse(this_block, block_name, trigger_inputs)) + else: + # PipelineBlock + result_blocks[block_name] = this_block + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_triggered_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs is not None: + + if not isinstance(trigger_inputs, (list, tuple, set)): + trigger_inputs = [trigger_inputs] + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if trigger_inputs is None: + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [trigger_inputs_all[0]] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) + def __repr__(self): class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" # Add empty line before + header += " " + "=" * 100 + "\n" # Add decorative line + header += " This pipeline block contains AutoPipelineBlocks where different blocks are dispatched at runtime based on your inputs.\n" + header += " You can use `get_triggered_blocks(input1, input2,...)` to get specific information for your trigger inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + header += " " + "=" * 100 + "\n" # Add decorative line + header += "\n" # Add empty line after # Components section expected_components = set(getattr(self, "expected_components", [])) @@ -524,44 +697,57 @@ def __repr__(self): f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs ) - # Detailed blocks section with data flow blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" - # Add inputs information if hasattr(block, "inputs"): - inputs_str = ", ".join(f"{name}={default}" for name, default in block.inputs) + inputs_str = ", ".join( + name if default is None else f"{name}={default}" + for name, default in block.inputs + ) blocks_str += f" inputs: {inputs_str}\n" - # Add intermediates information if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): intermediates_str = "" if hasattr(block, "intermediates_inputs"): - intermediates_str += f"{', '.join(block.intermediates_inputs)}" + inputs_set = set(block.intermediates_inputs) + intermediates_str += ", ".join(f"*{inp}" if inp in (getattr(block, "intermediates_outputs", set())) else inp + for inp in block.intermediates_inputs) if hasattr(block, "intermediates_outputs"): if intermediates_str: intermediates_str += " -> " - intermediates_str += f"{', '.join(block.intermediates_outputs)}" + outputs_set = set(block.intermediates_outputs) + new_outputs = outputs_set - inputs_set if hasattr(block, "intermediates_inputs") else outputs_set + intermediates_str += ", ".join(new_outputs) if intermediates_str: blocks_str += f" intermediates: {intermediates_str}\n" blocks_str += "\n" + + inputs_str = " inputs:\n " + ", ".join( + f"{name}={default}" if default is not None else f"{name}" + for name, default in self.inputs + ) + + modified_inputs = [f"*{inp}" if inp in self.intermediates_outputs else inp for inp in self.intermediates_inputs] + new_outputs = [out for out in self.intermediates_outputs if out not in self.intermediates_inputs] intermediates_str = ( - "\n Intermediates:\n" - f" - inputs: {', '.join(self.intermediates_inputs)}\n" - f" - outputs: {', '.join(self.intermediates_outputs)}\n" + "\n Intermediates:\n" + f" - inputs: {', '.join(modified_inputs)}\n" + f" - outputs: {', '.join(new_outputs)}\n" f" - final outputs: {', '.join(self.final_intermediates_outputs)}" ) return ( - f"{class_name}(\n" + f"{header}\n" f"{components_str}\n" f"{auxiliaries_str}\n" f"{configs_str}\n" f"{blocks_str}\n" + f"{inputs_str}\n" f"{intermediates_str}\n" f")" ) @@ -785,30 +971,34 @@ def default_call_parameters(self) -> Dict[str, Any]: def __repr__(self): output = "ModularPipeline:\n" output += "==============================\n\n" + + block = self.pipeline_block + if hasattr(block, "trigger_inputs") and block.trigger_inputs: + output += "\n" + output += " Trigger Inputs:\n" + output += " --------------\n" + output += f" This pipeline contains dynamic blocks that are selected at runtime based on your inputs.\n" + output += f" • Trigger inputs: {block.trigger_inputs}\n" + output += f" • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs\n" + output += "\n" output += "Pipeline Block:\n" output += "--------------\n" - block = self.pipeline_block if hasattr(block, "blocks"): output += f"{block.__class__.__name__}\n" - # Add sub-blocks information + base_class = block.__class__.__bases__[0].__name__ + output += f" (Class: {base_class})\n" if base_class != "object" else "\n" for sub_block_name, sub_block in block.blocks.items(): - output += f" • {sub_block_name} ({sub_block.__class__.__name__}) \n" + if hasattr(block, "block_trigger_inputs"): + trigger_input = block.block_to_trigger_map[sub_block_name] + trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" + output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" + else: + output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" else: output += f"{block.__class__.__name__}\n" output += "\n" - if hasattr(block, "intermediates_outputs"): - intermediates_str = f"-> {', '.join(block.intermediates_outputs)}" - output += f" {intermediates_str}\n" - output += "\n" - - # Add final intermediate outputs for SequentialPipelineBlocks - if hasattr(block, "final_intermediate_output"): - final_intermediates_str = f" (final intermediate outputs: {', '.join(block.final_intermediate_output)})" - output += f" {final_intermediates_str}\n" - output += "\n" - # List the components registered in the pipeline output += "Registered Components:\n" output += "----------------------\n" @@ -836,6 +1026,19 @@ def __repr__(self): output += "--------------------------\n" for name in self.pipeline_block.intermediates_inputs: output += f"{name}: \n" + + + if hasattr(block, "intermediates_outputs"): + output += "\nIntermediate outputs:\n" + output += "--------------------------\n" + output += f"{', '.join(block.intermediates_outputs)}\n\n" + + # Add final intermediate outputs section at the bottom + if hasattr(block, "final_intermediates_outputs"): + output += "Final intermediate outputs:\n" + output += "--------------------------\n" + output += f"{', '.join(block.final_intermediates_outputs)}\n" + return output # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index e4b4295ef7d1..36476a1909a5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -131,9 +131,12 @@ class StableDiffusionXLInputStep(PipelineBlock): def inputs(self) -> List[Tuple[str, Any]]: return [ ("prompt", None), - ("prompt_embeds", None), ] + @property + def intermediates_inputs(self) -> List[str]: + return ["prompt_embeds"] + @property def intermediates_outputs(self) -> List[str]: return ["batch_size"] @@ -141,7 +144,7 @@ def intermediates_outputs(self) -> List[str]: @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: prompt = state.get_input("prompt") - prompt_embeds = state.get_input("prompt_embeds") + prompt_embeds = state.get_intermediate("prompt_embeds") if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -168,15 +171,15 @@ def inputs(self) -> List[Tuple[str, Any]]: ("negative_prompt", None), ("negative_prompt_2", None), ("cross_attention_kwargs", None), - ("prompt_embeds", None), - ("negative_prompt_embeds", None), - ("pooled_prompt_embeds", None), - ("negative_pooled_prompt_embeds", None), ("num_images_per_prompt", 1), ("guidance_scale", 5.0), ("clip_skip", None), ] + @property + def intermediates_inputs(self) -> List[str]: + return ["prompt_embeds", "negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"] + @property def intermediates_outputs(self) -> List[str]: return [ @@ -263,14 +266,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: negative_prompt = state.get_input("negative_prompt") negative_prompt_2 = state.get_input("negative_prompt_2") cross_attention_kwargs = state.get_input("cross_attention_kwargs") - prompt_embeds = state.get_input("prompt_embeds") - negative_prompt_embeds = state.get_input("negative_prompt_embeds") - pooled_prompt_embeds = state.get_input("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_input("negative_pooled_prompt_embeds") num_images_per_prompt = state.get_input("num_images_per_prompt") guidance_scale = state.get_input("guidance_scale") clip_skip = state.get_input("clip_skip") + prompt_embeds = state.get_intermediate("prompt_embeds") + negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") + pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") + negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + do_classifier_free_guidance = guidance_scale > 1.0 device = pipeline._execution_device @@ -335,7 +339,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype","preprocess_kwargs"] + return ["batch_size", "dtype", "preprocess_kwargs"] @property def intermediates_outputs(self) -> List[str]: @@ -1771,8 +1775,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): TEXT2IMAGE_BLOCKS = OrderedDict([ - ("input", StableDiffusionXLInputStep), ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep), @@ -1781,8 +1785,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): ]) IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("input", StableDiffusionXLInputStep), ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLVaeEncoderStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), @@ -1792,8 +1796,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): ]) INPAINT_BLOCKS = OrderedDict([ - ("input", StableDiffusionXLInputStep), ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), @@ -1807,8 +1811,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): ]) AUTO_BLOCKS = OrderedDict([ - ("input", StableDiffusionXLInputStep), ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), From 0966663d2a0522b10cb3276b0df03460f68534d2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 11 Jan 2025 19:15:54 +0100 Subject: [PATCH 044/170] adjust print --- src/diffusers/pipelines/modular_pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 30e237d9c675..8544e5363499 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -669,8 +669,9 @@ def __repr__(self): if self.trigger_inputs: header += "\n" # Add empty line before header += " " + "=" * 100 + "\n" # Add decorative line - header += " This pipeline block contains AutoPipelineBlocks where different blocks are dispatched at runtime based on your inputs.\n" - header += " You can use `get_triggered_blocks(input1, input2,...)` to get specific information for your trigger inputs.\n" + header += " This pipeline block contains dynamic blocks that are selected at runtime based on your inputs.\n" + header += " You can use `get_triggered_blocks(input1, input2,...)` to see which blocks will be used for your trigger inputs.\n" + header += " Use `get_triggered_blocks()` to see blocks will be used for default inputs (when no trigger inputs are provided)\n" header += f" Trigger Inputs: {self.trigger_inputs}\n" header += " " + "=" * 100 + "\n" # Add decorative line header += "\n" # Add empty line after @@ -980,6 +981,7 @@ def __repr__(self): output += f" This pipeline contains dynamic blocks that are selected at runtime based on your inputs.\n" output += f" • Trigger inputs: {block.trigger_inputs}\n" output += f" • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs\n" + output += f" • Use .pipeline_block.get_triggered_blocks() to see blocks will be used for default inputs (when no trigger inputs are provided)\n" output += "\n" output += "Pipeline Block:\n" From 7f897a9fc4a00f5b24b88e7125a606931c30f86e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 12 Jan 2025 04:50:45 +0100 Subject: [PATCH 045/170] fix --- src/diffusers/pipelines/modular_pipeline.py | 12 ++++- .../pipeline_stable_diffusion_xl_modular.py | 54 ++++++++++--------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 8544e5363499..a13f0dccda54 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -82,9 +82,9 @@ def to_dict(self) -> Dict[str, Any]: def __repr__(self): def format_value(v): if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(\n dtype={v.dtype}, shape={v.shape}\n {v})" + return f"Tensor(dtype={v.dtype}, shape={v.shape})" elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(\n dtype={v[0].dtype}, shape={v[0].shape}\n {v[0]}), ...]" + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" else: return repr(v) @@ -238,6 +238,10 @@ def __init__(self): if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") default_blocks = [t for t in self.block_trigger_inputs if t is None] + # can only have 1 or 0 default block, and has to put in the last + # the order of blocksmatters here because the first block with matching trigger will be dispatched + # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] + # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img if len(default_blocks) > 1 or ( len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None ): @@ -248,6 +252,7 @@ def __init__(self): # Map trigger inputs to block objects self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) @property @@ -324,6 +329,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if input_name is not None and state.get_input(input_name) is not None: block = self.trigger_to_block_map[input_name] break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break if block is None: logger.warning(f"skipping auto block: {self.__class__.__name__}") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 35860186339f..4b2b33c2ff70 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1112,6 +1112,7 @@ def intermediates_inputs(self) -> List[str]: "mask", # inpainting "masked_image_latents", # inpainting "noise", # inpainting + "image_latents", # inpainting ] @property @@ -2028,25 +2029,24 @@ class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): block_trigger_inputs = ["mask_image", "image"] -class StableDiffusionXLAutoSetTimestepsStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLSetTimestepsStep] - block_names = ["img2img", "text2img"] - block_trigger_inputs = ["image", None] +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] -class StableDiffusionXLAutoPrepareLatentsStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLPrepareLatentsStep] - block_names = ["inpaint","img2img", "text2img"] - block_trigger_inputs = ["mask_image", "image", None] +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] -class StableDiffusionXLAutoPrepareAdditionalConditioningStep(AutoPipelineBlocks): - block_classes = [ - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareAdditionalConditioningStep, - ] - block_names = ["img2img", "text2img"] - block_trigger_inputs = ["image", None] +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): @@ -2064,10 +2064,10 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), - ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), ("decode", StableDiffusionXLDecodeStep) ]) @@ -2099,11 +2099,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): AUTO_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep), - ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), ("denoise", StableDiffusionXLAutoDenoiseStep), ("decode", StableDiffusionXLAutoDecodeStep) ]) @@ -2138,11 +2135,18 @@ def vae_scale_factor(self): vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) return vae_scale_factor + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + @property def num_channels_latents(self): num_channels_latents = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_latents = self.unet.config.in_channels + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels return num_channels_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids From a6804de4a2073d974e03b7d3a75c4ed2d46db59e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 12 Jan 2025 16:24:01 +0100 Subject: [PATCH 046/170] add controlnet union to auto & fix for pag --- .../pipeline_stable_diffusion_xl_modular.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 4b2b33c2ff70..61c50012f1a6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1598,7 +1598,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ("control_guidance_start", 0.0), ("control_guidance_end", 1.0), ("controlnet_conditioning_scale", 1.0), - ("control_mode", 0), + ("control_mode", None), ("guess_mode", False), ("num_images_per_prompt", 1), ("guidance_scale", 5.0), @@ -1791,8 +1791,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: control_type = ( control_type.reshape(1, -1) .to(device, dtype=prompt_embeds.dtype) - .repeat(batch_size * num_images_per_prompt * 2, 1) ) + control_type = pipeline.controlnet_guider.prepare_input(control_type, control_type) + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # prepare latents for unet using the guider @@ -2050,9 +2051,9 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet", "unet"] - block_trigger_inputs = ["control_image", None] + block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_union", "controlnet", "unet"] + block_trigger_inputs = ["control_mode", "control_image", None] class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): From 7007f72409ce6eaf2519a606f1e991c6f276bfa4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 16 Jan 2025 11:44:24 +0100 Subject: [PATCH 047/170] InputParam, OutputParam, get_auto_doc --- src/diffusers/pipelines/modular_pipeline.py | 634 ++++++++++++++---- .../pipeline_stable_diffusion_xl_modular.py | 466 +++++++------ 2 files changed, 775 insertions(+), 325 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index a13f0dccda54..8a0ae02aa385 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -20,6 +20,7 @@ import torch from tqdm.auto import tqdm +import re from ..configuration_utils import ConfigMixin from ..utils import ( @@ -100,6 +101,249 @@ def format_value(v): f")" ) +@dataclass +class InputParam: + name: str + default: Any = None + required: bool = False + description: str = "" + type_hint: Any = Any + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + +@dataclass +class OutputParam: + name: str + description: str = "" + type_hint: Any = Any + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(block) -> str: + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + block: Pipeline block with potential intermediates + + Returns: + str: Formatted string like "input1, Required(input2) -> output1, output2" + """ + # Handle inputs + input_parts = [] + if hasattr(block, "intermediates_inputs"): + for inp in block.intermediates_inputs: + parts = [] + # Check if input is required + if hasattr(block, "required_intermediates_inputs") and inp.name in block.required_intermediates_inputs: + parts.append("Required") + + # Get base name or modified name + name = inp.name + if hasattr(block, "intermediates_outputs") and name in {out.name for out in block.intermediates_outputs}: + name = f"*{name}" + + # Combine Required() wrapper with possibly starred name + if parts: + input_parts.append(f"Required({name})") + else: + input_parts.append(name) + + # Handle outputs + output_parts = [] + if hasattr(block, "intermediates_outputs"): + outputs = [out.name for out in block.intermediates_outputs] + if hasattr(block, "intermediates_inputs"): + # Only show new outputs if we have inputs + inputs_set = {inp.name for inp in block.intermediates_inputs} + outputs = [out for out in outputs if out not in inputs_set] + output_parts.extend(outputs) + + # Combine with arrow notation if both inputs and outputs exist + if input_parts and output_parts: + return f"{', '.join(input_parts)} -> {', '.join(output_parts)}" + elif input_parts: + return ', '.join(input_parts) + elif output_parts: + return ', '.join(output_parts) + return "" + + +def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + if not input_params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text: str, indent: str, max_length: int) -> str: + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + # Calculate word length including space + word_length = len(word) + (1 if current_line else 0) + + # Check if adding this word would exceed the max length + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + # Join lines with proper indentation + return f"\n{indent}".join(lines) + + # Add the "Args:" header + formatted_params.append(f"{base_indent}Args:") + + for param in input_params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is optional + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + if not output_params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text: str, indent: str, max_length: int) -> str: + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the "Returns:" header + formatted_params.append(f"{base_indent}Returns:") + + for param in output_params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) class PipelineBlock: # YiYi Notes: do we need this? @@ -109,18 +353,33 @@ class PipelineBlock: model_name = None @property - def inputs(self) -> Tuple[Tuple[str, Any], ...]: - # (input_name, default_value) - return () + def inputs(self) -> List[InputParam]: + return [] @property - def intermediates_inputs(self) -> List[str]: + def intermediates_inputs(self) -> List[InputParam]: return [] @property - def intermediates_outputs(self) -> List[str]: + def intermediates_outputs(self) -> List[OutputParam]: return [] + @property + def required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + def __init__(self): self.components: Dict[str, Any] = {} self.auxiliaries: Dict[str, Any] = {} @@ -161,19 +420,12 @@ def __repr__(self): ) # Inputs section - inputs = "inputs: " + ", ".join( - f"{name}={default}" if default is not None else name - for name, default in self.inputs - ) + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str # Intermediates section - input_set = set(self.intermediates_inputs) - output_set = set(self.intermediates_outputs) - - modified_inputs = [f"{item}*" for item in self.intermediates_inputs] - new_outputs = [item for item in self.intermediates_outputs if item not in input_set] - - intermediates = f"intermediates: {', '.join(modified_inputs)} -> {', '.join(new_outputs)}" + intermediates_str = format_intermediates_short(self) + intermediates = f"Intermediates:\n {intermediates_str}" return ( f"{class_name}(\n" @@ -185,35 +437,91 @@ def __repr__(self): f")" ) + def get_doc_string(self): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Returns: + str: A formatted string containing information about call parameters, intermediate inputs/outputs, + and final intermediate outputs. + """ + output = "Call Parameters:\n" + output += "------------------------\n" + output += format_input_params(self.inputs, indent_level=2) -def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]]) -> List[Tuple[str, Any]]: + output += "\n\nIntermediate inputs:\n" + output += "--------------------------\n" + output += format_input_params(self.intermediates_inputs, indent_level=2) + + if hasattr(self, "intermediates_outputs"): + output += "\n\nIntermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.intermediates_outputs, indent_level=2) + + if hasattr(self, "final_intermediates_outputs"): + output += "\nFinal intermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.final_intermediates_outputs, indent_level=2) + + return output + + + +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of (name, default_value) tuples from different blocks. For duplicate inputs, updates only if - current value is None and new value is not None. Warns if multiple non-None default values exist for the same input. + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. Args: - named_input_lists: List of tuples containing (block_name, input_list) pairs + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects """ - combined_dict = {} - value_sources = {} + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name for block_name, inputs in named_input_lists: - for name, value in inputs: - if name in combined_dict: - current_value = combined_dict[name] - if current_value is not None and value is not None and current_value != value: + for input_param in inputs: + if input_param.name in combined_dict: + current_param = combined_dict[input_param.name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): warnings.warn( - f"Multiple different default values found for input '{name}': " - f"{current_value} (from block '{value_sources[name]}') and " - f"{value} (from block '{block_name}'). Using {current_value}." + f"Multiple different default values found for input '{input_param.name}': " + f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." ) - if current_value is None and value is not None: - combined_dict[name] = value - value_sources[name] = block_name + if current_param.default is None and input_param.default is not None: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name else: - combined_dict[name] = value - value_sources[name] = block_name - return list(combined_dict.items()) + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + + return list(combined_dict.values()) + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if output_param.name not in combined_dict: + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) class AutoPipelineBlocks: @@ -307,18 +615,62 @@ def configs(self): configs.update(block.configs) return configs + @property + def required_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + @property + def required_intermediates_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediates_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + + # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - return combine_inputs(*named_inputs) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + @property def intermediates_inputs(self) -> List[str]: - return list(set().union(*(block.intermediates_inputs for block in self.blocks.values()))) + named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediates_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs @property def intermediates_outputs(self) -> List[str]: - return list(set().union(*(block.intermediates_outputs for block in self.blocks.values()))) + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -417,24 +769,11 @@ def __repr__(self): sections.append(f" Block: {block.__class__.__name__}") - if hasattr(block, "inputs"): - inputs_str = ", ".join( - name if default is None else f"{name}={default}" - for name, default in block.inputs - ) - if inputs_str: - sections.append(f" inputs: {inputs_str}") - - if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): - intermediates_str = "" - if hasattr(block, "intermediates_inputs"): - intermediates_str += f"{', '.join(block.intermediates_inputs)}" - if hasattr(block, "intermediates_outputs"): - if intermediates_str: - intermediates_str += " -> " - intermediates_str += f"{', '.join(block.intermediates_outputs)}" - if intermediates_str: - sections.append(f" intermediates: {intermediates_str}") + inputs_str = format_inputs_short(block.inputs) + sections.append(f" inputs:\n {inputs_str}") + + intermediates_str = f" intermediates:\n {format_intermediates_short(block)}" + sections.append(intermediates_str) sections.append("") @@ -445,6 +784,33 @@ def __repr__(self): f")" ) + def get_doc_string(self): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Returns: + str: A formatted string containing information about call parameters, intermediate inputs/outputs, + and final intermediate outputs. + """ + output = "Call Parameters:\n" + output += "------------------------\n" + output += format_input_params(self.inputs, indent_level=2) + + output += "\n\nIntermediate inputs:\n" + output += "--------------------------\n" + output += format_input_params(self.intermediates_inputs, indent_level=2) + + if hasattr(self, "intermediates_outputs"): + output += "\n\nIntermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.intermediates_outputs, indent_level=2) + + if hasattr(self, "final_intermediates_outputs"): + output += "\nFinal intermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.final_intermediates_outputs, indent_level=2) + + return output class SequentialPipelineBlocks: """ @@ -527,28 +893,60 @@ def configs(self): configs.update(block.configs) return configs + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - return combine_inputs(*named_inputs) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs @property def intermediates_inputs(self) -> List[str]: - inputs = set() + inputs = [] outputs = set() # Go through all blocks in order for block in self.blocks.values(): # Add inputs that aren't in outputs yet - inputs.update(input_name for input_name in block.intermediates_inputs if input_name not in outputs) + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) # Add this block's outputs - outputs.update(block.intermediates_outputs) + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) - return list(inputs) + return inputs @property def intermediates_outputs(self) -> List[str]: - return list(set().union(*(block.intermediates_outputs for block in self.blocks.values()))) + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs @property def final_intermediates_outputs(self) -> List[str]: @@ -710,44 +1108,28 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" - if hasattr(block, "inputs"): - inputs_str = ", ".join( - name if default is None else f"{name}={default}" - for name, default in block.inputs - ) - blocks_str += f" inputs: {inputs_str}\n" - - if hasattr(block, "intermediates_inputs") or hasattr(block, "intermediates_outputs"): - intermediates_str = "" - if hasattr(block, "intermediates_inputs"): - inputs_set = set(block.intermediates_inputs) - intermediates_str += ", ".join(f"*{inp}" if inp in (getattr(block, "intermediates_outputs", set())) else inp - for inp in block.intermediates_inputs) - - if hasattr(block, "intermediates_outputs"): - if intermediates_str: - intermediates_str += " -> " - outputs_set = set(block.intermediates_outputs) - new_outputs = outputs_set - inputs_set if hasattr(block, "intermediates_inputs") else outputs_set - intermediates_str += ", ".join(new_outputs) - - if intermediates_str: - blocks_str += f" intermediates: {intermediates_str}\n" + inputs_str = format_inputs_short(block.inputs) + + blocks_str += f" inputs: {inputs_str}\n" + + intermediates_str = format_intermediates_short(block) + + if intermediates_str: + blocks_str += f" intermediates: {intermediates_str}\n" blocks_str += "\n" - - inputs_str = " inputs:\n " + ", ".join( - f"{name}={default}" if default is not None else f"{name}" - for name, default in self.inputs - ) - - modified_inputs = [f"*{inp}" if inp in self.intermediates_outputs else inp for inp in self.intermediates_inputs] - new_outputs = [out for out in self.intermediates_outputs if out not in self.intermediates_inputs] + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + final_intermediates_outputs = [out.name for out in self.final_intermediates_outputs] + + intermediates_str_short = format_intermediates_short(self) + intermediates_input_str = intermediates_str_short.split('->')[0].strip() # "Required(latents), crops_coords" + intermediates_output_str = intermediates_str_short.split('->')[1].strip() intermediates_str = ( "\n Intermediates:\n" - f" - inputs: {', '.join(modified_inputs)}\n" - f" - outputs: {', '.join(new_outputs)}\n" - f" - final outputs: {', '.join(self.final_intermediates_outputs)}" + f" - inputs: {intermediates_input_str}\n" + f" - outputs: {intermediates_output_str}\n" + f" - final outputs: {', '.join(final_intermediates_outputs)}" ) return ( @@ -761,6 +1143,33 @@ def __repr__(self): f")" ) + def get_doc_string(self): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Returns: + str: A formatted string containing information about call parameters, intermediate inputs/outputs, + and final intermediate outputs. + """ + output = "Call Parameters:\n" + output += "------------------------\n" + output += format_input_params(self.inputs, indent_level=2) + + output += "\n\nIntermediate inputs:\n" + output += "--------------------------\n" + output += format_input_params(self.intermediates_inputs, indent_level=2) + + if hasattr(self, "intermediates_outputs"): + output += "\n\nIntermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.intermediates_outputs, indent_level=2) + + if hasattr(self, "final_intermediates_outputs"): + output += "\nFinal intermediate outputs:\n" + output += "--------------------------\n" + output += format_output_params(self.final_intermediates_outputs, indent_level=2) + + return output class ModularPipeline(ConfigMixin): """ @@ -894,16 +1303,17 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs + intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] for name, default in default_params.items(): if name in input_params: - if name not in self.pipeline_block.intermediates_inputs: + if name not in intermediates_inputs: state.add_input(name, input_params.pop(name)) else: state.add_input(name, input_params[name]) elif name not in state.inputs: state.add_input(name, default) - for name in self.pipeline_block.intermediates_inputs: + for name in intermediates_inputs: if name in input_params: state.add_intermediate(name, input_params.pop(name)) @@ -973,8 +1383,8 @@ def update_states(self, **kwargs): @property def default_call_parameters(self) -> Dict[str, Any]: params = {} - for name, default in self.pipeline_block.inputs: - params[name] = default + for input_param in self.pipeline_block.inputs: + params[input_param.name] = input_param.default return params def __repr__(self): @@ -1026,28 +1436,8 @@ def __repr__(self): output += f"{name}: {config!r}\n" output += "\n" - # List the default call parameters - output += "Call Parameters:\n" - output += "------------------------\n" - for name, default in self.default_call_parameters.items(): - output += f"{name}: {default!r}\n" - - output += "\nIntermediate inputs:\n" - output += "--------------------------\n" - for name in self.pipeline_block.intermediates_inputs: - output += f"{name}: \n" - - - if hasattr(block, "intermediates_outputs"): - output += "\nIntermediate outputs:\n" - output += "--------------------------\n" - output += f"{', '.join(block.intermediates_outputs)}\n\n" - - # Add final intermediate outputs section at the bottom - if hasattr(block, "final_intermediates_outputs"): - output += "Final intermediate outputs:\n" - output += "--------------------------\n" - output += f"{', '.join(block.final_intermediates_outputs)}\n" + # List the call parameters + output += self.pipeline_block.get_doc_string() return output diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 61c50012f1a6..4256ffa3d463 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -38,6 +38,8 @@ ModularPipeline, PipelineBlock, PipelineState, + InputParam, + OutputParam, SequentialPipelineBlocks, ) from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -128,18 +130,22 @@ class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("prompt", None), + InputParam( + name="prompt", + description="The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.", + type_hint=Union[str, List[str]], + ), ] @property def intermediates_inputs(self) -> List[str]: - return ["prompt_embeds"] + return [InputParam("prompt_embeds")] @property def intermediates_outputs(self) -> List[str]: - return ["batch_size"] + return [OutputParam("batch_size")] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -164,30 +170,63 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("prompt", None), - ("prompt_2", None), - ("negative_prompt", None), - ("negative_prompt_2", None), - ("cross_attention_kwargs", None), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("clip_skip", None), + InputParam( + name="prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.", + ), + InputParam( + name="prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders", + ), + InputParam( + name="negative_prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", + ), + InputParam( + name="negative_prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders", + ), + InputParam( + name="cross_attention_kwargs", + type_hint=Optional[dict], + description="A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor]", + ), + InputParam( + name="num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt.", + ), + InputParam( + name="guidance_scale", + type_hint=float, + default=5.0, + description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", + ), + InputParam( + name="clip_skip", + type_hint=Optional[int], + ), ] @property def intermediates_inputs(self) -> List[str]: - return ["prompt_embeds", "negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"] + return [InputParam("prompt_embeds"), InputParam("negative_prompt_embeds"), InputParam("pooled_prompt_embeds"), InputParam("negative_pooled_prompt_embeds")] @property def intermediates_outputs(self) -> List[str]: return [ - "prompt_embeds", - "negative_prompt_embeds", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "dtype", + OutputParam("prompt_embeds"), + OutputParam("negative_prompt_embeds"), + OutputParam("pooled_prompt_embeds"), + OutputParam("negative_pooled_prompt_embeds"), + OutputParam("dtype"), ] def __init__(self): @@ -328,22 +367,25 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("image", None), - ("generator", None), - ("height", None), - ("width", None), - ("num_images_per_prompt", 1), + InputParam(name="image", required=True), + InputParam(name="generator"), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_images_per_prompt", default=1), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype", "preprocess_kwargs"] + return [ + InputParam("batch_size", description="batch size for generated image_latents, if not provided, same number of images as input"), + InputParam("dtype"), + InputParam("preprocess_kwargs")] @property def intermediates_outputs(self) -> List[str]: - return ["image_latents"] + return [OutputParam("image_latents")] def __init__(self): super().__init__() @@ -414,24 +456,24 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), - ("strength", 0.3), - ("denoising_start", None), - ("num_images_per_prompt", 1), + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + InputParam("num_images_per_prompt", default=1), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size"] + return [InputParam("batch_size", required=True)] @property def intermediates_outputs(self) -> List[str]: - return ["timesteps", "num_inference_steps", "latent_timestep"] + return [OutputParam("timesteps"), OutputParam("num_inference_steps"), OutputParam("latent_timestep")] def __init__(self): super().__init__() @@ -491,17 +533,17 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("num_inference_steps", 50), - ("timesteps", None), - ("sigmas", None), - ("denoising_end", None), + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), ] @property - def intermediates_outputs(self) -> List[str]: - return ["timesteps", "num_inference_steps"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("timesteps"), OutputParam("num_inference_steps")] def __init__(self): super().__init__() @@ -541,24 +583,24 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - ("height", None), - ("width", None), - ("generator", None), - ("num_images_per_prompt", 1), - ("image", None), - ("mask_image", None), - ("padding_mask_crop", None), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("num_images_per_prompt", default=1), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), ] @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype"] + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("batch_size"), InputParam("dtype")] @property - def intermediates_outputs(self) -> List[str]: - return ["image_latents", "mask", "masked_image_latents", "crops_coords"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents"), OutputParam("mask"), OutputParam("masked_image_latents"), OutputParam("crops_coords")] def __init__(self): super().__init__() @@ -648,20 +690,26 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), - ("denoising_start", None), - ("strength", 0.9999), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam("strength", default=0.9999), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype", "latent_timestep", "image_latents", "mask", "masked_image_latents"] + return [ + InputParam("batch_size", required=True), + InputParam("latent_timestep", required=True), + InputParam("image_latents", required=True), + InputParam("mask", required=True), + InputParam("masked_image_latents"), # only for inpainting-specific unet + InputParam("dtype")] @property def intermediates_outputs(self) -> List[str]: - return ["latents", "mask", "masked_image_latents", "noise"] + return [OutputParam("latents"), OutputParam("mask"), OutputParam("masked_image_latents"), OutputParam("noise")] def __init__(self): super().__init__() @@ -750,19 +798,23 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), - ("denoising_start", None), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), ] @property def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype", "latent_timestep", "image_latents"] + return [ + InputParam("latent_timestep", required=True), + InputParam("image_latents", required=True), + InputParam("batch_size"), + InputParam("dtype")] @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents")] def __init__(self): super().__init__() @@ -785,6 +837,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin if dtype is None: dtype = pipeline.vae.dtype + if batch_size is None: + batch_size = image_latents.shape[0] device = pipeline._execution_device add_noise = True if denoising_start is None else False @@ -812,20 +866,20 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("height", None), - ("width", None), - ("generator", None), - ("latents", None), - ("num_images_per_prompt", 1), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), ] @property - def intermediates_inputs(self) -> List[str]: - return ["batch_size", "dtype"] + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("batch_size", required=True), InputParam("dtype")] @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents")] def __init__(self): super().__init__() @@ -888,25 +942,25 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("original_size", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("aesthetic_score", 6.0), - ("negative_aesthetic_score", 2.0), + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), ] @property - def intermediates_inputs(self) -> List[str]: - return ["latents", "batch_size", "pooled_prompt_embeds"] + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("latents", required=True), InputParam("pooled_prompt_embeds", required=True)] @property - def intermediates_outputs(self) -> List[str]: - return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids"), OutputParam("negative_add_time_ids"), OutputParam("timestep_cond")] def __init__(self): super().__init__() @@ -928,11 +982,12 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin negative_aesthetic_score = state.get_input("negative_aesthetic_score") latents = state.get_intermediate("latents") - batch_size = state.get_intermediate("batch_size") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") device = pipeline._execution_device + batch_size = latents.shape[0] + if hasattr(pipeline, "vae") and pipeline.vae is not None: vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) else: @@ -994,23 +1049,23 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("original_size", None), - ("target_size", None), - ("negative_original_size", None), - ("negative_target_size", None), - ("crops_coords_top_left", (0, 0)), - ("negative_crops_coords_top_left", (0, 0)), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), ] @property - def intermediates_inputs(self) -> List[str]: - return ["latents", "batch_size", "pooled_prompt_embeds"] + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("latents", required=True), InputParam("pooled_prompt_embeds", required=True)] @property - def intermediates_outputs(self) -> List[str]: - return ["add_time_ids", "negative_add_time_ids", "timestep_cond"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids"), OutputParam("negative_add_time_ids"), OutputParam("timestep_cond")] @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1025,11 +1080,12 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin device = state.get_input("device") latents = state.get_intermediate("latents") - batch_size = state.get_intermediate("batch_size") pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") device = pipeline._execution_device + batch_size = latents.shape[0] + height, width = latents.shape[-2:] height = height * pipeline.vae_scale_factor width = width * pipeline.vae_scale_factor @@ -1088,36 +1144,36 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ("guider_kwargs", None), + InputParam("guidance_scale", default=5.0), + InputParam("guidance_rescale", default=0.0), + InputParam("cross_attention_kwargs", default=None), + InputParam("generator", default=None), + InputParam("eta", default=0.0), + InputParam("guider_kwargs", default=None), ] @property def intermediates_inputs(self) -> List[str]: return [ - "latents", - "timesteps", - "num_inference_steps", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "add_time_ids", - "negative_add_time_ids", - "timestep_cond", - "prompt_embeds", - "negative_prompt_embeds", - "mask", # inpainting - "masked_image_latents", # inpainting - "noise", # inpainting - "image_latents", # inpainting + InputParam("latents", required=True), + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("add_time_ids", required=True), + InputParam("negative_add_time_ids", required=True), + InputParam("prompt_embeds", required=True), + InputParam("negative_prompt_embeds", required=True), + InputParam("timestep_cond"), # LCM + InputParam("mask"), # inpainting + InputParam("masked_image_latents"), # inpainting + InputParam("noise"), # inpainting + InputParam("image_latents"), # inpainting ] @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents")] def __init__(self): super().__init__() @@ -1143,9 +1199,11 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: add_time_ids = state.get_intermediate("add_time_ids") negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - timestep_cond = state.get_intermediate("timestep_cond") latents = state.get_intermediate("latents") + #LCM + timestep_cond = state.get_intermediate("timestep_cond") + # inpainting mask = state.get_intermediate("mask") masked_image_latents = state.get_intermediate("masked_image_latents") @@ -1272,44 +1330,43 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("control_image", None), - ("control_guidance_start", 0.0), - ("control_guidance_end", 1.0), - ("controlnet_conditioning_scale", 1.0), - ("guess_mode", False), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ("guider_kwargs", None), + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=5.0), + InputParam("guidance_rescale", default=0.0), + InputParam("cross_attention_kwargs", default=None), + InputParam("generator", default=None), + InputParam("eta", default=0.0), + InputParam("guider_kwargs", default=None), ] @property def intermediates_inputs(self) -> List[str]: return [ - "latents", - "batch_size", - "timesteps", - "num_inference_steps", - "prompt_embeds", - "negative_prompt_embeds", - "add_time_ids", - "negative_add_time_ids", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "timestep_cond", - "mask", - "masked_image_latents", - "noise", - "image_latents", - "crops_coords", + InputParam("latents", required=True), + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("prompt_embeds", required=True), + InputParam("negative_prompt_embeds", required=True), + InputParam("add_time_ids", required=True), + InputParam("negative_add_time_ids", required=True), + InputParam("pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("timestep_cond"), # LCM + InputParam("mask"), # inpainting + InputParam("masked_image_latents"), # inpainting + InputParam("noise"), # inpainting + InputParam("image_latents"), # inpainting + InputParam("crops_coords"), # inpainting ] @property - def intermediates_outputs(self) -> List[str]: - return ["latents"] + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents")] def __init__(self): super().__init__() @@ -1337,7 +1394,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") guess_mode = state.get_input("guess_mode") - batch_size = state.get_intermediate("batch_size") latents = state.get_intermediate("latents") timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") @@ -1376,6 +1432,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: device = pipeline._execution_device + batch_size = latents.shape[0] height, width = latents.shape[-2:] height = height * pipeline.vae_scale_factor @@ -1594,44 +1651,44 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("control_image", None), - ("control_guidance_start", 0.0), - ("control_guidance_end", 1.0), - ("controlnet_conditioning_scale", 1.0), - ("control_mode", None), - ("guess_mode", False), - ("num_images_per_prompt", 1), - ("guidance_scale", 5.0), - ("guidance_rescale", 0.0), - ("cross_attention_kwargs", None), - ("generator", None), - ("eta", 0.0), - ("guider_kwargs", None), + (InputParam("control_image", required=True)), + (InputParam("control_guidance_start", default=0.0)), + (InputParam("control_guidance_end", default=1.0)), + (InputParam("controlnet_conditioning_scale", default=1.0)), + (InputParam("control_mode", required=True)), + (InputParam("guess_mode", default=False)), + (InputParam("num_images_per_prompt", default=1)), + (InputParam("guidance_scale", default=5.0)), + (InputParam("guidance_rescale", default=0.0)), + (InputParam("cross_attention_kwargs")), + (InputParam("generator")), + (InputParam("eta", default=0.0)), + (InputParam("guider_kwargs")), ] @property def intermediates_inputs(self) -> List[str]: return [ - "latents", - "batch_size", - "timesteps", - "num_inference_steps", - "prompt_embeds", - "negative_prompt_embeds", - "add_time_ids", - "negative_add_time_ids", - "pooled_prompt_embeds", - "negative_pooled_prompt_embeds", - "timestep_cond", - "mask", - "noise", - "image_latents", - "crops_coords", + InputParam("latents", required=True), + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("prompt_embeds", required=True), + InputParam("negative_prompt_embeds", required=True), + InputParam("add_time_ids", required=True), + InputParam("negative_add_time_ids", required=True), + InputParam("pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("timestep_cond"), # LCM + InputParam("mask"), # inpainting + InputParam("masked_image_latents"), # inpainting + InputParam("noise"), # inpainting + InputParam("image_latents"), # inpainting + InputParam("crops_coords"), # inpainting ] @property def intermediates_outputs(self) -> List[str]: - return ["latents"] + return [OutputParam("latents")] def __init__(self): super().__init__() @@ -1660,7 +1717,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: control_mode = state.get_input("control_mode") guess_mode = state.get_input("guess_mode") - batch_size = state.get_intermediate("batch_size") latents = state.get_intermediate("latents") timesteps = state.get_intermediate("timesteps") num_inference_steps = state.get_intermediate("num_inference_steps") @@ -1681,7 +1737,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: crops_coords = state.get_intermediate("crops_coords") device = pipeline._execution_device - + batch_size = latents.shape[0] height, width = latents.shape[-2:] height = height * pipeline.vae_scale_factor width = width * pipeline.vae_scale_factor @@ -1882,17 +1938,17 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("output_type", "pil"), - ("return_dict", True), + (InputParam("output_type", default="pil")), + (InputParam("return_dict", default=True)), ] @property def intermediates_inputs(self) -> List[str]: - return ["latents"] + return [InputParam("latents", required=True)] @property def intermediates_outputs(self) -> List[str]: - return ["images"] + return [OutputParam("images")] def __init__(self): super().__init__() @@ -1961,18 +2017,18 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: return [ - ("image", None), - ("mask_image", None), - ("padding_mask_crop", None), + (InputParam("image", required=True)), + (InputParam("mask_image", required=True)), + (InputParam("padding_mask_crop", default=None)), ] @property def intermediates_inputs(self) -> List[str]: - return ["crops_coords", "images"] + return [InputParam("images", required=True), InputParam("crops_coords")] @property def intermediates_outputs(self) -> List[str]: - return ["images"] + return [OutputParam("images")] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1995,11 +2051,15 @@ class StableDiffusionXLOutputStep(PipelineBlock): @property def inputs(self) -> List[Tuple[str, Any]]: - return [("return_dict", True)] + return [(InputParam("return_dict", default=True))] + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("images", required=True)] + @property def intermediates_outputs(self) -> List[str]: - return ["images"] + return [OutputParam("images")] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: From a226920b52a608c1c3ee3f3cbc44cee20bf30b7d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 17 Jan 2025 01:37:18 +0100 Subject: [PATCH 048/170] get_block_state make it less verbose --- src/diffusers/pipelines/modular_pipeline.py | 92 ++++-- .../pipeline_stable_diffusion_xl_modular.py | 296 +++++++----------- 2 files changed, 182 insertions(+), 206 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 8a0ae02aa385..5ead9dbbe5e4 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -18,6 +18,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union +from types import SimpleNamespace + import torch from tqdm.auto import tqdm import re @@ -144,7 +146,7 @@ def format_inputs_short(inputs): return inputs_str -def format_intermediates_short(block) -> str: +def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: """ Formats intermediate inputs and outputs of a block into a string representation. @@ -156,33 +158,31 @@ def format_intermediates_short(block) -> str: """ # Handle inputs input_parts = [] - if hasattr(block, "intermediates_inputs"): - for inp in block.intermediates_inputs: - parts = [] - # Check if input is required - if hasattr(block, "required_intermediates_inputs") and inp.name in block.required_intermediates_inputs: - parts.append("Required") - - # Get base name or modified name - name = inp.name - if hasattr(block, "intermediates_outputs") and name in {out.name for out in block.intermediates_outputs}: - name = f"*{name}" - - # Combine Required() wrapper with possibly starred name - if parts: - input_parts.append(f"Required({name})") - else: - input_parts.append(name) + + for inp in intermediates_inputs: + parts = [] + # Check if input is required + if inp.name in required_intermediates_inputs: + parts.append("Required") + + # Get base name or modified name + name = inp.name + if name in {out.name for out in intermediates_outputs}: + name = f"*{name}" + + # Combine Required() wrapper with possibly starred name + if parts: + input_parts.append(f"Required({name})") + else: + input_parts.append(name) # Handle outputs output_parts = [] - if hasattr(block, "intermediates_outputs"): - outputs = [out.name for out in block.intermediates_outputs] - if hasattr(block, "intermediates_inputs"): - # Only show new outputs if we have inputs - inputs_set = {inp.name for inp in block.intermediates_inputs} - outputs = [out for out in outputs if out not in inputs_set] - output_parts.extend(outputs) + outputs = [out.name for out in intermediates_outputs] + # Only show new outputs if we have inputs + inputs_set = {inp.name for inp in intermediates_inputs} + outputs = [out for out in outputs if out not in inputs_set] + output_parts.extend(outputs) # Combine with arrow notation if both inputs and outputs exist if input_parts and output_parts: @@ -363,6 +363,10 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [] + + @property + def outputs(self) -> List[OutputParam]: + return [] @property def required_inputs(self) -> List[str]: @@ -424,7 +428,7 @@ def __repr__(self): inputs = "Inputs:\n " + inputs_str # Intermediates section - intermediates_str = format_intermediates_short(self) + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) intermediates = f"Intermediates:\n {intermediates_str}" return ( @@ -465,6 +469,36 @@ def get_doc_string(self): return output + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + data[input_param.name] = value + + # Check intermediates + for input_param in self.intermediates_inputs: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + data[input_param.name] = value + + return SimpleNamespace(**data) + + def add_block_state(self, state: PipelineState, block_state: SimpleNamespace): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) + + for output_param in self.outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Output '{output_param.name}' is missing in block state") + state.add_output(output_param.name, getattr(block_state, output_param.name)) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -772,7 +806,7 @@ def __repr__(self): inputs_str = format_inputs_short(block.inputs) sections.append(f" inputs:\n {inputs_str}") - intermediates_str = f" intermediates:\n {format_intermediates_short(block)}" + intermediates_str = f" intermediates:\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}" sections.append(intermediates_str) sections.append("") @@ -1112,7 +1146,7 @@ def __repr__(self): blocks_str += f" inputs: {inputs_str}\n" - intermediates_str = format_intermediates_short(block) + intermediates_str = format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs) if intermediates_str: blocks_str += f" intermediates: {intermediates_str}\n" @@ -1122,7 +1156,7 @@ def __repr__(self): inputs_str = " Inputs:\n " + inputs_str final_intermediates_outputs = [out.name for out in self.final_intermediates_outputs] - intermediates_str_short = format_intermediates_short(self) + intermediates_str_short = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) intermediates_input_str = intermediates_str_short.split('->')[0].strip() # "Required(latents), crops_coords" intermediates_output_str = intermediates_str_short.split('->')[1].strip() intermediates_str = ( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 4256ffa3d463..ef7ac79d9aba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -299,66 +299,51 @@ def check_inputs( @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs - prompt = state.get_input("prompt") - prompt_2 = state.get_input("prompt_2") - negative_prompt = state.get_input("negative_prompt") - negative_prompt_2 = state.get_input("negative_prompt_2") - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - clip_skip = state.get_input("clip_skip") - - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") + # Get inputs and intermediates + data = self.get_block_state(state) - do_classifier_free_guidance = guidance_scale > 1.0 - device = pipeline._execution_device + data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.device = pipeline._execution_device self.check_inputs( pipeline, - prompt, - prompt_2, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + data.prompt, + data.prompt_2, + data.negative_prompt, + data.negative_prompt_2, + data.prompt_embeds, + data.negative_prompt_embeds, + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, ) # Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + data.text_encoder_lora_scale = ( + data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None ) ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + data.prompt_embeds, + data.negative_prompt_embeds, + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, ) = pipeline.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + data.prompt, + data.prompt_2, + data.device, + data.num_images_per_prompt, + data.do_classifier_free_guidance, + data.negative_prompt, + data.negative_prompt_2, + prompt_embeds=data.prompt_embeds, + negative_prompt_embeds=data.negative_prompt_embeds, + pooled_prompt_embeds=data.pooled_prompt_embeds, + negative_pooled_prompt_embeds=data.negative_pooled_prompt_embeds, + lora_scale=data.text_encoder_lora_scale, + clip_skip=data.clip_skip, ) + data.dtype = data.prompt_embeds.dtype # Add outputs - state.add_intermediate("prompt_embeds", prompt_embeds) - state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds) - state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds) - state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds) - state.add_intermediate("dtype", prompt_embeds.dtype) + self.add_block_state(state, data) return pipeline, state @@ -394,59 +379,48 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - - height = state.get_input("height") - width = state.get_input("width") - image = state.get_input("image") - - preprocess_kwargs = state.get_intermediate("preprocess_kwargs") or {} - batch_size = state.get_intermediate("batch_size") - dtype = state.get_intermediate("dtype") - - device = pipeline._execution_device - if dtype is None: - dtype = pipeline.vae.dtype + data = self.get_block_state(state) + data.preprocess_kwargs = data.preprocess_kwargs or {} + data.device = pipeline._execution_device + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - image = pipeline.image_processor.preprocess(image, height=height, width=width, **preprocess_kwargs) - image = image.to(device=device, dtype=dtype) + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) + data.image = data.image.to(device=data.device, dtype=data.dtype) - if batch_size is None: - batch_size = image.shape[0] + data.batch_size = data.batch_size if data.batch_size is not None else data.image.shape[0] - batch_size = batch_size * num_images_per_prompt + data.batch_size = data.batch_size * data.num_images_per_prompt # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(generator, list) and len(generator) != batch_size: + if isinstance(data.generator, list) and len(data.generator) != data.batch_size: raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" + f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." ) - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + elif isinstance(data.generator, list): + if data.image.shape[0] < data.batch_size and data.batch_size % data.image.shape[0] == 0: + data.image = torch.cat([data.image] * (data.batch_size // data.image.shape[0]), dim=0) + elif data.image.shape[0] < data.batch_size and data.batch_size % data.image.shape[0] != 0: raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + f"Cannot duplicate `image` of batch size {data.image.shape[0]} to effective batch_size {data.batch_size} " ) - image_latents = pipeline._encode_vae_image(image=image, generator=generator) + data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + if data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] == 0: # expand latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + data.additional_image_per_prompt = data.batch_size // data.image_latents.shape[0] + data.image_latents = torch.cat([data.image_latents] * additional_image_per_prompt, dim=0) + elif data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] != 0: raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + f"Cannot duplicate `image` of batch size {data.image_latents.shape[0]} to {data.batch_size} text prompts." ) else: - image_latents = torch.cat([image_latents], dim=0) + data.image_latents = torch.cat([data.image_latents], dim=0) - state.add_intermediate("image_latents", image_latents) + self.add_block_state(state, data) return pipeline, state @@ -481,49 +455,36 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - num_inference_steps = state.get_input("num_inference_steps") - timesteps = state.get_input("timesteps") - sigmas = state.get_input("sigmas") - denoising_end = state.get_input("denoising_end") - - # image to image only - strength = state.get_input("strength") - denoising_start = state.get_input("denoising_start") - num_images_per_prompt = state.get_input("num_images_per_prompt") - - # image to image only - batch_size = state.get_intermediate("batch_size") + data = self.get_block_state(state) - device = pipeline._execution_device + data.device = pipeline._execution_device - timesteps, num_inference_steps = retrieve_timesteps( - pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + data.timesteps, data.num_inference_steps = retrieve_timesteps( + pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas ) def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = pipeline.get_timesteps( - num_inference_steps, - strength, - device, - denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + data.timesteps, data.num_inference_steps = pipeline.get_timesteps( + data.num_inference_steps, + data.strength, + data.device, + denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( + if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: + data.discrete_timestep_cutoff = int( round( pipeline.scheduler.config.num_train_timesteps - - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] + data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) + data.timesteps = data.timesteps[:data.num_inference_steps] - state.add_intermediate("timesteps", timesteps) - state.add_intermediate("num_inference_steps", num_inference_steps) - state.add_intermediate("latent_timestep", latent_timestep) + self.add_block_state(state, data) return pipeline, state @@ -551,30 +512,25 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - num_inference_steps = state.get_input("num_inference_steps") - timesteps = state.get_input("timesteps") - sigmas = state.get_input("sigmas") - denoising_end = state.get_input("denoising_end") + data = self.get_block_state(state) - device = pipeline._execution_device + data.device = pipeline._execution_device - timesteps, num_inference_steps = retrieve_timesteps( - pipeline.scheduler, num_inference_steps, device, timesteps, sigmas + data.timesteps, data.num_inference_steps = retrieve_timesteps( + pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas ) - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( + if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: + data.discrete_timestep_cutoff = int( round( pipeline.scheduler.config.num_train_timesteps - - (denoising_end * pipeline.scheduler.config.num_train_timesteps) + - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - state.add_intermediate("timesteps", timesteps) - state.add_intermediate("num_inference_steps", num_inference_steps) + data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) + data.timesteps = data.timesteps[:data.num_inference_steps] + self.add_block_state(state, data) return pipeline, state @@ -611,73 +567,55 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - num_images_per_prompt = state.get_input("num_images_per_prompt") - # YiYi TODO: we don't put generator back to state but it actually gets used and updated - # it is ok but think about how we can handle mutable inputs better in PipelineState so user would be aware - generator = state.get_input("generator") - - height = state.get_input("height") - width = state.get_input("width") - # inpaint only - image = state.get_input("image") - padding_mask_crop = state.get_input("padding_mask_crop") - mask_image = state.get_input("mask_image") - - batch_size = state.get_intermediate("batch_size") - dtype = state.get_intermediate("dtype") + data = self.get_block_state(state) - if dtype is None: - dtype = pipeline.vae.dtype - device = pipeline._execution_device + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device - if padding_mask_crop is not None: - crops_coords = pipeline.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" + if data.padding_mask_crop is not None: + data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) + data.resize_mode = "fill" else: - crops_coords = None - resize_mode = "default" + data.crops_coords = None + data.resize_mode = "default" - image = pipeline.image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode) - image = image.to(dtype=torch.float32) + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) + data.image = data.image.to(dtype=torch.float32) - mask = pipeline.mask_processor.preprocess(mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords) - masked_image = image * (mask < 0.5) + data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) + data.masked_image = data.image * (data.mask < 0.5) - if batch_size is None: - batch_size = image.shape[0] + data.batch_size = data.batch_size if data.batch_size is not None else data.image.shape[0] - batch_size = batch_size * num_images_per_prompt - image = image.to(device=device, dtype=dtype) - image_latents = pipeline._encode_vae_image(image=image, generator=generator) + data.batch_size = data.batch_size * data.num_images_per_prompt + data.image = data.image.to(device=data.device, dtype=data.dtype) + data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + if data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] == 0: # expand latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + data.additional_image_per_prompt = data.batch_size // data.image_latents.shape[0] + data.image_latents = torch.cat([data.image_latents] * data.additional_image_per_prompt, dim=0) + elif data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] != 0: raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + f"Cannot duplicate `image` of batch size {data.image_latents.shape[0]} to {data.batch_size} text prompts." ) else: - image_latents = torch.cat([image_latents], dim=0) + data.image_latents = torch.cat([data.image_latents], dim=0) # 7. Prepare mask latent variables - mask, masked_image_latents = pipeline.prepare_mask_latents( - mask, - masked_image, - batch_size, - height, - width, - dtype, - device, - generator, + data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( + data.mask, + data.masked_image, + data.batch_size, + data.height, + data.width, + data.dtype, + data.device, + data.generator, ) - state.add_intermediate("mask", mask) - state.add_intermediate("masked_image_latents", masked_image_latents) - state.add_intermediate("image_latents", image_latents) - state.add_intermediate("crops_coords", crops_coords) + self.add_block_state(state, data) return pipeline, state @@ -2061,6 +1999,10 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("images")] + @property + def outputs(self) -> List[Tuple[str, Any]]: + return [(OutputParam("images"))] + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: images = state.get_intermediate("images") From 77b5fa59c54b3fbe7898a1362cfd4991cd087c96 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Jan 2025 04:12:07 +0100 Subject: [PATCH 049/170] make it work with lora has both text_encoder & unet --- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 286d0a12bc71..d9a0ca9cdc9c 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -330,7 +330,7 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None and _pipeline.hf_device_map is None: + if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 351295e938ff..be7f26bf6b99 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -644,24 +644,24 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, + unet_config=self.unet.config if hasattr(self, "unet") else None, **kwargs, ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_unet( - state_dict, - network_alphas=network_alphas, - unet=self.unet, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + if hasattr(self, "unet"): + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: + if len(text_encoder_state_dict) > 0 and hasattr(self, "text_encoder"): self.load_lora_into_text_encoder( text_encoder_state_dict, network_alphas=network_alphas, @@ -674,7 +674,7 @@ def load_lora_weights( ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: + if len(text_encoder_2_state_dict) > 0 and hasattr(self, "text_encoder_2"): self.load_lora_into_text_encoder( text_encoder_2_state_dict, network_alphas=network_alphas, From 6e2fe26bfd46a65e3044b44b75df24f489f03c39 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Jan 2025 08:04:12 +0100 Subject: [PATCH 050/170] fix more for lora --- src/diffusers/loaders/lora_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index d9a0ca9cdc9c..a45100e9298a 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -557,8 +557,10 @@ def set_adapters( # Decompose weights into weights for denoiser and text encoders. _component_adapter_weights = {} for component in self._lora_loadable_modules: - model = getattr(self, component) - + model = getattr(self, component, None) + if model is None: + logger.warning(f"Model {component} not found in pipeline.") + continue for adapter_name, weights in zip(adapter_names, adapter_weights): if isinstance(weights, dict): component_adapter_weights = weights.pop(component, None) From 68a5185c86716129aaca6c42b900c512d3446961 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 20 Jan 2025 03:36:01 +0100 Subject: [PATCH 051/170] refactor more, ipadapter node, lora node --- src/diffusers/guider.py | 6 +- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/ip_adapter.py | 256 +++ .../pipeline_stable_diffusion_xl_modular.py | 1640 ++++++++--------- 4 files changed, 1047 insertions(+), 857 deletions(-) diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py index e5942ff560b1..7445b7ba97af 100644 --- a/src/diffusers/guider.py +++ b/src/diffusers/guider.py @@ -169,7 +169,7 @@ def prepare_input( else: negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - if not self._is_prepared_input(cond_input) and negative_cond_input is None: + if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: raise ValueError( "`negative_cond_input` is required when cond_input does not already contains negative conditional input" ) @@ -447,7 +447,7 @@ def prepare_input( else: negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - if not self._is_prepared_input(cond_input) and negative_cond_input is None: + if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: raise ValueError( "`negative_cond_input` is required when cond_input does not already contains negative conditional input" ) @@ -688,7 +688,7 @@ def prepare_input( else: negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - if not self._is_prepared_input(cond_input) and negative_cond_input is None: + if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: raise ValueError( "`negative_cond_input` is required when cond_input does not already contains negative conditional input" ) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 2db8b53db498..65ef5d3e2336 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -79,6 +79,7 @@ def text_encoder_attn_modules(text_encoder): "IPAdapterMixin", "FluxIPAdapterMixin", "SD3IPAdapterMixin", + "ModularIPAdapterMixin", ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -97,6 +98,7 @@ def text_encoder_attn_modules(text_encoder): FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin, + ModularIPAdapterMixin, ) from .lora_pipeline import ( AmusedLoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 7b691d1fe16e..895dce22dc12 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -354,6 +354,262 @@ def unload_ip_adapter(self): ) self.unet.set_attn_processor(attn_procs) +class ModularIPAdapterMixin: + """Mixin for handling IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + subfolder: Union[str, List[str]], + weight_name: Union[str, List[str]], + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + state_dicts.append(state_dict) + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 + default_clip_size = 224 + clip_image_size = ( + self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size + ) + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) + + unet_name = getattr(self, "unet_name", "unet") + unet = getattr(self, unet_name) + unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + if not USE_PEFT_BACKEND: + logger.warning("PEFT backend is required to load these weights.") + else: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + + def set_ip_adapter_scale(self, scale): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) + ``` + """ + unet_name = getattr(self, "unet_name", "unet") + unet = getattr(self, unet_name) + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): + if isinstance( + attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to " + f"{len(attn_processor.scale)} IP-Adapter." + ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + + # remove hidden encoder + self.unet.encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = None + + # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj` + if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None: + self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj + self.unet.text_encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = "text_proj" + + # restore original Unet attention processors layers + attn_procs = {} + for name, value in self.unet.attn_processors.items(): + attn_processor_class = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + ) + attn_procs[name] = ( + attn_processor_class + if isinstance( + value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ) + else value.__class__() + ) + self.unet.set_attn_processor(attn_procs) + class FluxIPAdapterMixin: """Mixin for handling Flux IP Adapters.""" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index ef7ac79d9aba..ab171e365199 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -21,7 +21,7 @@ from ...guider import CFGGuider from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...models import ControlNetModel, ImageProjection from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.lora import adjust_lora_scale_text_encoder @@ -133,33 +133,106 @@ class StableDiffusionXLInputStep(PipelineBlock): def inputs(self) -> List[InputParam]: return [ InputParam( - name="prompt", - description="The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.", - type_hint=Union[str, List[str]], + name="num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt.", ), ] @property def intermediates_inputs(self) -> List[str]: - return [InputParam("prompt_embeds")] + return [ + InputParam("prompt_embeds", required=True), + InputParam("negative_prompt_embeds"), + InputParam("pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds"), + InputParam("ip_adapter_embeds"), + InputParam("negative_ip_adapter_embeds"), + ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("batch_size")] + return [ + OutputParam("batch_size"), + OutputParam("dtype"), + OutputParam("prompt_embeds"), + OutputParam("negative_prompt_embeds"), + OutputParam("pooled_prompt_embeds"), + OutputParam("negative_pooled_prompt_embeds"), + OutputParam("ip_adapter_embeds"), + OutputParam("negative_ip_adapter_embeds"), + ] + + def check_inputs(self, pipeline, data): + + if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: + if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {data.negative_prompt_embeds.shape}." + ) + + if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): + if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {data.negative_ip_adapter_embeds[i].shape}." + ) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - prompt = state.get_input("prompt") - prompt_embeds = state.get_intermediate("prompt_embeds") + data = self.get_block_state(state) + self.check_inputs(pipeline, data) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + data.batch_size = data.prompt_embeds.shape[0] + data.dtype = data.prompt_embeds.dtype - state.add_intermediate("batch_size", batch_size) + _, seq_len, _ = data.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + + if data.negative_prompt_embeds is not None: + _, seq_len, _ = data.negative_prompt_embeds.shape + data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + + data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + + if data.negative_pooled_prompt_embeds is not None: + data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + + if data.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): + data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) + + if data.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): + data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) + + self.add_block_state(state, data) return pipeline, state @@ -197,12 +270,6 @@ def inputs(self) -> List[InputParam]: type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor]", ), - InputParam( - name="num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt.", - ), InputParam( name="guidance_scale", type_hint=float, @@ -215,9 +282,6 @@ def inputs(self) -> List[InputParam]: ), ] - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("prompt_embeds"), InputParam("negative_prompt_embeds"), InputParam("pooled_prompt_embeds"), InputParam("negative_pooled_prompt_embeds")] @property def intermediates_outputs(self) -> List[str]: @@ -226,7 +290,6 @@ def intermediates_outputs(self) -> List[str]: OutputParam("negative_prompt_embeds"), OutputParam("pooled_prompt_embeds"), OutputParam("negative_pooled_prompt_embeds"), - OutputParam("dtype"), ] def __init__(self): @@ -237,85 +300,23 @@ def __init__(self): self.components["tokenizer"] = None self.components["tokenizer_2"] = None - @staticmethod - def check_inputs( - pipeline, - prompt, - prompt_2, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - ): - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + def check_inputs(self, pipeline, data): - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) + if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") + elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Get inputs and intermediates data = self.get_block_state(state) + self.check_inputs(pipeline, data) data.do_classifier_free_guidance = data.guidance_scale > 1.0 data.device = pipeline._execution_device - self.check_inputs( - pipeline, - data.prompt, - data.prompt_2, - data.negative_prompt, - data.negative_prompt_2, - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) # Encode input prompt data.text_encoder_lora_scale = ( @@ -330,18 +331,17 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.prompt, data.prompt_2, data.device, - data.num_images_per_prompt, + 1, data.do_classifier_free_guidance, data.negative_prompt, data.negative_prompt_2, - prompt_embeds=data.prompt_embeds, - negative_prompt_embeds=data.negative_prompt_embeds, - pooled_prompt_embeds=data.pooled_prompt_embeds, - negative_pooled_prompt_embeds=data.negative_pooled_prompt_embeds, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, lora_scale=data.text_encoder_lora_scale, clip_skip=data.clip_skip, ) - data.dtype = data.prompt_embeds.dtype # Add outputs self.add_block_state(state, data) return pipeline, state @@ -358,13 +358,11 @@ def inputs(self) -> List[InputParam]: InputParam(name="generator"), InputParam(name="height"), InputParam(name="width"), - InputParam(name="num_images_per_prompt", default=1), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", description="batch size for generated image_latents, if not provided, same number of images as input"), InputParam("dtype"), InputParam("preprocess_kwargs")] @@ -384,13 +382,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.device = pipeline._execution_device data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) data.image = data.image.to(device=data.device, dtype=data.dtype) - data.batch_size = data.batch_size if data.batch_size is not None else data.image.shape[0] - - data.batch_size = data.batch_size * data.num_images_per_prompt + data.batch_size = data.image.shape[0] # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) if isinstance(data.generator, list) and len(data.generator) != data.batch_size: @@ -399,32 +394,77 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." ) - elif isinstance(data.generator, list): - if data.image.shape[0] < data.batch_size and data.batch_size % data.image.shape[0] == 0: - data.image = torch.cat([data.image] * (data.batch_size // data.image.shape[0]), dim=0) - elif data.image.shape[0] < data.batch_size and data.batch_size % data.image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {data.image.shape[0]} to effective batch_size {data.batch_size} " - ) data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - if data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] == 0: - # expand latents for batch_size - data.additional_image_per_prompt = data.batch_size // data.image_latents.shape[0] - data.image_latents = torch.cat([data.image_latents] * additional_image_per_prompt, dim=0) - elif data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {data.image_latents.shape[0]} to {data.batch_size} text prompts." - ) - else: - data.image_latents = torch.cat([data.image_latents], dim=0) - + self.add_block_state(state, data) return pipeline, state +class StableDiffusionXLLoraStep(PipelineBlock): + expected_components = ["text_encoder", "text_encoder_2", "unet"] + model_name = "stable-diffusion-xl" + + def __init__(self): + super().__init__() + self.components["text_encoder"] = None + self.components["text_encoder_2"] = None + self.components["unet"] = None + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + expected_components = ["image_encoder", "feature_extractor", "unet"] + model_name = "stable-diffusion-xl" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("ip_adapter_image", required=True), + InputParam("guidance_scale", default=5.0), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("ip_adapter_embeds"), OutputParam("negative_ip_adapter_embeds")] + + def __init__(self): + super().__init__() + self.components["image_encoder"] = None + self.components["feature_extractor"] = None + self.components["unet"] = None + + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) + + data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.device = pipeline._execution_device + + data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( + ip_adapter_image=data.ip_adapter_image, + ip_adapter_image_embeds=None, + device=data.device, + num_images_per_prompt=1, + do_classifier_free_guidance=data.do_classifier_free_guidance, + ) + if data.do_classifier_free_guidance: + data.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(data.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + data.negative_ip_adapter_embeds.append(negative_image_embeds) + data.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, data) + return pipeline, state + + + class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] model_name = "stable-diffusion-xl" @@ -544,7 +584,6 @@ def inputs(self) -> List[InputParam]: InputParam("height"), InputParam("width"), InputParam("generator"), - InputParam("num_images_per_prompt", default=1), InputParam("image", required=True), InputParam("mask_image", required=True), InputParam("padding_mask_crop"), @@ -552,7 +591,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("batch_size"), InputParam("dtype")] + return [InputParam("dtype")] @property def intermediates_outputs(self) -> List[OutputParam]: @@ -585,23 +624,9 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) data.masked_image = data.image * (data.mask < 0.5) - data.batch_size = data.batch_size if data.batch_size is not None else data.image.shape[0] - - data.batch_size = data.batch_size * data.num_images_per_prompt + data.batch_size = data.image.shape[0] data.image = data.image.to(device=data.device, dtype=data.dtype) data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - if data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] == 0: - # expand latents for batch_size - data.additional_image_per_prompt = data.batch_size // data.image_latents.shape[0] - data.image_latents = torch.cat([data.image_latents] * data.additional_image_per_prompt, dim=0) - elif data.batch_size > data.image_latents.shape[0] and data.batch_size % data.image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {data.image_latents.shape[0]} to {data.batch_size} text prompts." - ) - else: - data.image_latents = torch.cat([data.image_latents], dim=0) - # 7. Prepare mask latent variables data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( @@ -655,76 +680,53 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - latents = state.get_input("latents") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - # image to image only - denoising_start = state.get_input("denoising_start") - - # inpaint only - strength = state.get_input("strength") - - # image to image only - latent_timestep = state.get_intermediate("latent_timestep") - - # YiYi Notes: mask and masked_image_latents should be intermediate outputs from StableDiffusionXLPrepareMaskedImageLatentsStep - image_latents = state.get_intermediate("image_latents") - mask = state.get_intermediate("mask") - masked_image_latents = state.get_intermediate("masked_image_latents") - - - batch_size = state.get_intermediate("batch_size") - dtype = state.get_intermediate("dtype") + data = self.get_block_state(state) - if dtype is None: - dtype = pipeline.vae.dtype - device = pipeline._execution_device + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device - is_strength_max = strength == 1.0 + data.is_strength_max = data.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents if hasattr(pipeline,"unet") and pipeline.unet is not None: if pipeline.unet.config.in_channels == 4: - masked_image_latents = None + data.masked_image_latents = None - add_noise = True if denoising_start is None else False + data.add_noise = True if data.denoising_start is None else False - height = image_latents.shape[-2] * pipeline.vae_scale_factor - width = image_latents.shape[-1] * pipeline.vae_scale_factor + data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor + data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - latents, noise = pipeline.prepare_latents_inpaint( - batch_size * num_images_per_prompt, + data.latents, data.noise = pipeline.prepare_latents_inpaint( + data.batch_size * data.num_images_per_prompt, pipeline.num_channels_latents, - height, - width, - dtype, - device, - generator, - latents, - image=image_latents, - timestep=latent_timestep, - is_strength_max=is_strength_max, - add_noise=add_noise, + data.height, + data.width, + data.dtype, + data.device, + data.generator, + data.latents, + image=data.image_latents, + timestep=data.latent_timestep, + is_strength_max=data.is_strength_max, + add_noise=data.add_noise, return_noise=True, return_image_latents=False, ) - # 7. Prepare mask latent variables - mask, masked_image_latents = pipeline.prepare_mask_latents( - mask, - masked_image_latents, - batch_size * num_images_per_prompt, - height, - width, - dtype, - device, - generator, + # 7. Prepare mask latent variables + data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( + data.mask, + data.masked_image_latents, + data.batch_size * data.num_images_per_prompt, + data.height, + data.width, + data.dtype, + data.device, + data.generator, ) - state.add_intermediate("latents", latents) - state.add_intermediate("mask", mask) - state.add_intermediate("masked_image_latents", masked_image_latents) - state.add_intermediate("noise", noise) + self.add_block_state(state, data) return pipeline, state @@ -747,7 +749,7 @@ def intermediates_inputs(self) -> List[str]: return [ InputParam("latent_timestep", required=True), InputParam("image_latents", required=True), - InputParam("batch_size"), + InputParam("batch_size", required=True), InputParam("dtype")] @property @@ -760,39 +762,24 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - latents = state.get_input("latents") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - - # image to image only - denoising_start = state.get_input("denoising_start") - - batch_size = state.get_intermediate("batch_size") - dtype = state.get_intermediate("dtype") - # image to image only - latent_timestep = state.get_intermediate("latent_timestep") - image_latents = state.get_intermediate("image_latents") - - if dtype is None: - dtype = pipeline.vae.dtype - if batch_size is None: - batch_size = image_latents.shape[0] - - device = pipeline._execution_device - add_noise = True if denoising_start is None else False - if latents is None: - latents = pipeline.prepare_latents_img2img( - image_latents, - latent_timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator, - add_noise, + data = self.get_block_state(state) + + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device + data.add_noise = True if data.denoising_start is None else False + if data.latents is None: + data.latents = pipeline.prepare_latents_img2img( + data.image_latents, + data.latent_timestep, + data.batch_size, + data.num_images_per_prompt, + data.dtype, + data.device, + data.generator, + data.add_noise, ) - state.add_intermediate("latents", latents) + self.add_block_state(state, data) return pipeline, state @@ -824,51 +811,43 @@ def __init__(self): self.components["scheduler"] = None @staticmethod - def check_inputs(pipeline, height, width): + def check_inputs(pipeline, data): if ( - height is not None - and height % pipeline.vae_scale_factor != 0 - or width is not None - and width % pipeline.vae_scale_factor != 0 + data.height is not None + and data.height % pipeline.vae_scale_factor != 0 + or data.width is not None + and data.width % pipeline.vae_scale_factor != 0 ): raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {height} and {width}." + f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." ) @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - latents = state.get_input("latents") - num_images_per_prompt = state.get_input("num_images_per_prompt") - generator = state.get_input("generator") - - # text to image only - height = state.get_input("height") - width = state.get_input("width") + data = self.get_block_state(state) - batch_size = state.get_intermediate("batch_size") - dtype = state.get_intermediate("dtype") - if dtype is None: - dtype = pipeline.vae.dtype + if data.dtype is None: + data.dtype = pipeline.vae.dtype - device = pipeline._execution_device + data.device = pipeline._execution_device - self.check_inputs(pipeline, height, width) + self.check_inputs(pipeline, data) - height = height or pipeline.default_sample_size * pipeline.vae_scale_factor - width = width or pipeline.default_sample_size * pipeline.vae_scale_factor - num_channels_latents = pipeline.num_channels_latents - latents = pipeline.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents, + data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor + data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor + data.num_channels_latents = pipeline.num_channels_latents + data.latents = pipeline.prepare_latents( + data.batch_size * data.num_images_per_prompt, + data.num_channels_latents, + data.height, + data.width, + data.dtype, + data.device, + data.generator, + data.latents, ) - state.add_intermediate("latents", latents) + self.add_block_state(state, data) return pipeline, state @@ -894,7 +873,11 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("latents", required=True), InputParam("pooled_prompt_embeds", required=True)] + return [ + InputParam("latents", required=True), + InputParam("pooled_prompt_embeds", required=True), + InputParam("batch_size", required=True), + ] @property def intermediates_outputs(self) -> List[OutputParam]: @@ -906,78 +889,53 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - original_size = state.get_input("original_size") - target_size = state.get_input("target_size") - negative_original_size = state.get_input("negative_original_size") - negative_target_size = state.get_input("negative_target_size") - crops_coords_top_left = state.get_input("crops_coords_top_left") - negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - - # image to image only - aesthetic_score = state.get_input("aesthetic_score") - negative_aesthetic_score = state.get_input("negative_aesthetic_score") - - latents = state.get_intermediate("latents") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - - device = pipeline._execution_device - - batch_size = latents.shape[0] - - if hasattr(pipeline, "vae") and pipeline.vae is not None: - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - else: - vae_scale_factor = 8 - - height, width = latents.shape[-2:] - height = height * vae_scale_factor - width = width * vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) + data = self.get_block_state(state) + data.device = pipeline._execution_device - if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: - text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim - else: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size - - add_time_ids, negative_add_time_ids = pipeline._get_add_time_ids_img2img( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + data.vae_scale_factor = pipeline.vae_scale_factor + + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * data.vae_scale_factor + data.width = data.width * data.vae_scale_factor + + data.original_size = data.original_size or (data.height, data.width) + data.target_size = data.target_size or (data.height, data.width) + + data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) + + if data.negative_original_size is None: + data.negative_original_size = data.original_size + if data.negative_target_size is None: + data.negative_target_size = data.target_size + + data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img( + data.original_size, + data.crops_coords_top_left, + data.target_size, + data.aesthetic_score, + data.negative_aesthetic_score, + data.negative_original_size, + data.negative_crops_coords_top_left, + data.negative_target_size, + dtype=data.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=data.text_encoder_projection_dim, ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) # Optionally get Guidance Scale Embedding for LCM - timestep_cond = None + data.timestep_cond = None if ( hasattr(pipeline, "unet") and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = pipeline.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - state.add_intermediate("add_time_ids", add_time_ids) - state.add_intermediate("negative_add_time_ids", negative_add_time_ids) - state.add_intermediate("timestep_cond", timestep_cond) + data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + data.timestep_cond = pipeline.get_guidance_scale_embedding( + data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=data.device, dtype=data.latents.dtype) + + self.add_block_state(state, data) return pipeline, state @@ -999,7 +957,11 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("latents", required=True), InputParam("pooled_prompt_embeds", required=True)] + return [ + InputParam("latents", required=True), + InputParam("pooled_prompt_embeds", required=True), + InputParam("batch_size", required=True), + ] @property def intermediates_outputs(self) -> List[OutputParam]: @@ -1007,71 +969,52 @@ def intermediates_outputs(self) -> List[OutputParam]: @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - original_size = state.get_input("original_size") - target_size = state.get_input("target_size") - negative_original_size = state.get_input("negative_original_size") - negative_target_size = state.get_input("negative_target_size") - crops_coords_top_left = state.get_input("crops_coords_top_left") - negative_crops_coords_top_left = state.get_input("negative_crops_coords_top_left") - num_images_per_prompt = state.get_input("num_images_per_prompt") - guidance_scale = state.get_input("guidance_scale") - device = state.get_input("device") - - latents = state.get_intermediate("latents") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - - device = pipeline._execution_device + data = self.get_block_state(state) + data.device = pipeline._execution_device - batch_size = latents.shape[0] + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * pipeline.vae_scale_factor + data.width = data.width * pipeline.vae_scale_factor - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor + data.original_size = data.original_size or (data.height, data.width) + data.target_size = data.target_size or (data.height, data.width) - original_size = original_size or (height, width) - target_size = target_size or (height, width) + data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None: - text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim - else: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - - add_time_ids = pipeline._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + data.add_time_ids = pipeline._get_add_time_ids( + data.original_size, + data.crops_coords_top_left, + data.target_size, + data.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=data.text_encoder_projection_dim, ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = pipeline._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - pooled_prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + + if data.negative_original_size is not None and data.negative_target_size is not None: + data.negative_add_time_ids = pipeline._get_add_time_ids( + data.negative_original_size, + data.negative_crops_coords_top_left, + data.negative_target_size, + data.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=data.text_encoder_projection_dim, ) else: - negative_add_time_ids = add_time_ids - negative_add_time_ids = negative_add_time_ids.repeat(batch_size * num_images_per_prompt, 1).to(device=device) + data.negative_add_time_ids = data.add_time_ids + data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) # Optionally get Guidance Scale Embedding for LCM - timestep_cond = None + data.timestep_cond = None if ( hasattr(pipeline, "unet") and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = pipeline.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - state.add_intermediate("add_time_ids", add_time_ids) - state.add_intermediate("negative_add_time_ids", negative_add_time_ids) - state.add_intermediate("timestep_cond", timestep_cond) + data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + data.timestep_cond = pipeline.get_guidance_scale_embedding( + data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim + ).to(device=data.device, dtype=data.latents.dtype) + + self.add_block_state(state, data) return pipeline, state @@ -1094,19 +1037,22 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam("latents", required=True), + InputParam("batch_size", required=True), InputParam("timesteps", required=True), InputParam("num_inference_steps", required=True), InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds"), InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids", required=True), + InputParam("negative_add_time_ids"), InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds", required=True), + InputParam("negative_prompt_embeds"), InputParam("timestep_cond"), # LCM InputParam("mask"), # inpainting InputParam("masked_image_latents"), # inpainting InputParam("noise"), # inpainting InputParam("image_latents"), # inpainting + InputParam("ip_adapter_embeds"), # ip adapter + InputParam("negative_ip_adapter_embeds"), # ip adapter ] @property @@ -1119,43 +1065,16 @@ def __init__(self): self.components["scheduler"] = None self.components["unet"] = None - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - guidance_rescale = state.get_input("guidance_rescale") - - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - generator = state.get_input("generator") - eta = state.get_input("eta") - guider_kwargs = state.get_input("guider_kwargs") - - batch_size = state.get_intermediate("batch_size") - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - - latents = state.get_intermediate("latents") - - #LCM - timestep_cond = state.get_intermediate("timestep_cond") - - # inpainting - mask = state.get_intermediate("mask") - masked_image_latents = state.get_intermediate("masked_image_latents") - noise = state.get_intermediate("noise") - image_latents = state.get_intermediate("image_latents") + def check_inputs(self, pipeline, data): num_channels_unet = pipeline.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if mask is None or masked_image_latents is None: + if data.mask is None or data.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = latents.shape[1] - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] + num_channels_latents = data.latents.shape[1] + num_channels_mask = data.mask.shape[1] + num_channels_masked_image = data.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" @@ -1165,98 +1084,107 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: " `pipeline.unet` or your `mask_image` or `image` input." ) - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") - disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + + data = self.get_block_state(state) + self.check_inputs(pipeline, data) + + data.num_channels_unet = pipeline.unet.config.in_channels + data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - guider_kwargs = guider_kwargs or {} - guider_kwargs = { - **guider_kwargs, - "disable_guidance": disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, + data.guider_kwargs = data.guider_kwargs or {} + data.guider_kwargs = { + **data.guider_kwargs, + "disable_guidance": data.disable_guidance, + "guidance_scale": data.guidance_scale, + "guidance_rescale": data.guidance_rescale, + "batch_size": data.batch_size, } - pipeline.guider.set_guider(pipeline, guider_kwargs) + pipeline.guider.set_guider(pipeline, data.guider_kwargs) # Prepare conditional inputs using the guider - prompt_embeds = pipeline.guider.prepare_input( - prompt_embeds, - negative_prompt_embeds, + data.prompt_embeds = pipeline.guider.prepare_input( + data.prompt_embeds, + data.negative_prompt_embeds, ) - add_time_ids = pipeline.guider.prepare_input( - add_time_ids, - negative_add_time_ids, + data.add_time_ids = pipeline.guider.prepare_input( + data.add_time_ids, + data.negative_add_time_ids, ) - pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + data.pooled_prompt_embeds = pipeline.guider.prepare_input( + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, ) - if num_channels_unet == 9: - mask = pipeline.guider.prepare_input(mask, mask) - masked_image_latents = pipeline.guider.prepare_input(masked_image_latents, masked_image_latents) + if data.num_channels_unet == 9: + data.mask = pipeline.guider.prepare_input(data.mask, data.mask) + data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds, - "time_ids": add_time_ids, + data.added_cond_kwargs = { + "text_embeds": data.pooled_prompt_embeds, + "time_ids": data.add_time_ids, } + if data.ip_adapter_embeds is not None: + data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) + data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + for i, t in enumerate(data.timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = pipeline.guider.prepare_input(latents, latents) - latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) # inpainting - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if data.num_channels_unet == 9: + data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) # predict the noise residual - noise_pred = pipeline.unet( - latent_model_input, + data.noise_pred = pipeline.unet( + data.latent_model_input, t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=data.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance( - noise_pred, + data.noise_pred = pipeline.guider.apply_guidance( + data.noise_pred, timestep=t, - latents=latents, + latents=data.latents, ) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if latents.dtype != latents_dtype: + data.latents_dtype = data.latents.dtype + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + data.latents = data.latents.to(data.latents_dtype) - if num_channels_unet == 4 and mask is not None and image_latents is not None: - init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] - init_latents_proper = image_latents - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = pipeline.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) + if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: + data.init_latents_proper = data.image_latents + if i < len(data.timesteps) - 1: + data.noise_timestep = data.timesteps[i + 1] + data.init_latents_proper = pipeline.scheduler.add_noise( + data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() pipeline.guider.reset_guider(pipeline) - state.add_intermediate("latents", latents) + self.add_block_state(state, data) return pipeline, state @@ -1286,20 +1214,23 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam("latents", required=True), + InputParam("batch_size", required=True), InputParam("timesteps", required=True), InputParam("num_inference_steps", required=True), InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds", required=True), + InputParam("negative_prompt_embeds"), InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids", required=True), + InputParam("negative_add_time_ids"), InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds"), InputParam("timestep_cond"), # LCM InputParam("mask"), # inpainting InputParam("masked_image_latents"), # inpainting InputParam("noise"), # inpainting InputParam("image_latents"), # inpainting InputParam("crops_coords"), # inpainting + InputParam("ip_adapter_embeds"), # ip adapter + InputParam("negative_ip_adapter_embeds"), # ip adapter ] @property @@ -1316,49 +1247,16 @@ def __init__(self): control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) self.auxiliaries["control_image_processor"] = control_image_processor - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - guidance_rescale = state.get_input("guidance_rescale") - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - guider_kwargs = state.get_input("guider_kwargs") - generator = state.get_input("generator") - eta = state.get_input("eta") - num_images_per_prompt = state.get_input("num_images_per_prompt") - # controlnet-specific inputs - control_image = state.get_input("control_image") - control_guidance_start = state.get_input("control_guidance_start") - control_guidance_end = state.get_input("control_guidance_end") - controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") - guess_mode = state.get_input("guess_mode") - - latents = state.get_intermediate("latents") - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") - - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - - timestep_cond = state.get_intermediate("timestep_cond") - - # inpainting - mask = state.get_intermediate("mask") - masked_image_latents = state.get_intermediate("masked_image_latents") - noise = state.get_intermediate("noise") - image_latents = state.get_intermediate("image_latents") - crops_coords = state.get_intermediate("crops_coords") + def check_inputs(self, pipeline, data): + num_channels_unet = pipeline.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if mask is None or masked_image_latents is None: + if data.mask is None or data.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = latents.shape[1] - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] + num_channels_latents = data.latents.shape[1] + num_channels_mask = data.mask.shape[1] + num_channels_masked_image = data.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" @@ -1368,216 +1266,236 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: " `pipeline.unet` or your `mask_image` or `image` input." ) + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + + data = self.get_block_state(state) + self.check_inputs(pipeline, data) + + data.num_channels_unet = pipeline.unet.config.in_channels + + # (1) prepare controlnet inputs + + data.device = pipeline._execution_device - device = pipeline._execution_device - batch_size = latents.shape[0] - - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * pipeline.vae_scale_factor + data.width = data.width * pipeline.vae_scale_factor - # prepare controlnet inputs controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): + data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] + elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): + data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] + elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], + data.control_guidance_start, data.control_guidance_end = ( + mult * [data.control_guidance_start], + mult * [data.control_guidance_end], ) - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): + data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) - global_pool_conditions = ( + # (1.3) + # global_pool_conditions + data.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) - guess_mode = guess_mode or global_pool_conditions + # (1.4) + # guess_mode + data.guess_mode = data.guess_mode or data.global_pool_conditions - # 4. Prepare image + # (1.5) + # control_image if isinstance(controlnet, ControlNetModel): - control_image = pipeline.prepare_control_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, + data.control_image = pipeline.prepare_control_image( + image=data.control_image, + width=data.width, + height=data.height, + batch_size=data.batch_size * data.num_images_per_prompt, + num_images_per_prompt=data.num_images_per_prompt, + device=data.device, dtype=controlnet.dtype, - crops_coords=crops_coords, + crops_coords=data.crops_coords, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] - for control_image_ in control_image: + for control_image_ in data.control_image: control_image = pipeline.prepare_control_image( image=control_image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, + width=data.width, + height=data.height, + batch_size=data.batch_size * data.num_images_per_prompt, + num_images_per_prompt=data.num_images_per_prompt, + device=data.device, dtype=controlnet.dtype, crops_coords=crops_coords, ) control_images.append(control_image) - control_image = control_images + data.control_image = control_images else: assert False - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): + # (1.6) + # controlnet_keep + data.controlnet_keep = [] + for i in range(len(data.timesteps)): keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) + 1.0 - float(i / len(data.timesteps) < s or (i + 1) / len(data.timesteps) > e) + for s, e in zip(data.control_guidance_start, data.control_guidance_end) ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - # Prepare conditional inputs for unet using the guider + # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - guider_kwargs = guider_kwargs or {} - guider_kwargs = { - **guider_kwargs, - "disable_guidance": disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, + data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + data.guider_kwargs = data.guider_kwargs or {} + data.guider_kwargs = { + **data.guider_kwargs, + "disable_guidance": data.disable_guidance, + "guidance_scale": data.guidance_scale, + "guidance_rescale": data.guidance_rescale, + "batch_size": data.batch_size, } - pipeline.guider.set_guider(pipeline, guider_kwargs) - prompt_embeds = pipeline.guider.prepare_input( - prompt_embeds, - negative_prompt_embeds, + pipeline.guider.set_guider(pipeline, data.guider_kwargs) + data.prompt_embeds = pipeline.guider.prepare_input( + data.prompt_embeds, + data.negative_prompt_embeds, ) - add_time_ids = pipeline.guider.prepare_input( - add_time_ids, - negative_add_time_ids, + data.add_time_ids = pipeline.guider.prepare_input( + data.add_time_ids, + data.negative_add_time_ids, ) - pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + data.pooled_prompt_embeds = pipeline.guider.prepare_input( + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, ) - if num_channels_unet == 9: - mask = pipeline.guider.prepare_input(mask, mask) - masked_image_latents = pipeline.guider.prepare_input(masked_image_latents, masked_image_latents) + if data.num_channels_unet == 9: + data.mask = pipeline.guider.prepare_input(data.mask, data.mask) + data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds, - "time_ids": add_time_ids, + data.added_cond_kwargs = { + "text_embeds": data.pooled_prompt_embeds, + "time_ids": data.add_time_ids, } - # Prepare conditional inputs for controlnet using the guider - controlnet_disable_guidance = True if disable_guidance or guess_mode else False - controlnet_guider_kwargs = guider_kwargs or {} - controlnet_guider_kwargs = { - **controlnet_guider_kwargs, - "disable_guidance": controlnet_disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, + if data.ip_adapter_embeds is not None: + data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) + data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + + # (3) Prepare conditional inputs for controlnet using the guider + data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False + data.controlnet_guider_kwargs = data.guider_kwargs or {} + data.controlnet_guider_kwargs = { + **data.controlnet_guider_kwargs, + "disable_guidance": data.controlnet_disable_guidance, + "guidance_scale": data.guidance_scale, + "guidance_rescale": data.guidance_rescale, + "batch_size": data.batch_size, } - pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) - controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) - controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), + pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) + data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) + data.controlnet_added_cond_kwargs = { + "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), + "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), } - # controlnet-specific inputs: control_image - control_image = pipeline.controlnet_guider.prepare_input(control_image, control_image) + data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + # (5) Denoise loop + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + for i, t in enumerate(data.timesteps): # prepare latents for unet using the guider - latent_model_input = pipeline.guider.prepare_input(latents, latents) + data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) # prepare latents for controlnet using the guider - control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) + data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + if isinstance(data.controlnet_keep[i], list): + data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] + data.controlnet_cond_scale = data.controlnet_conditioning_scale + if isinstance(data.controlnet_cond_scale, list): + data.controlnet_cond_scale = data.controlnet_cond_scale[0] + data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(control_model_input, t), + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + pipeline.scheduler.scale_model_input(data.control_model_input, t), t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, + encoder_hidden_states=data.controlnet_prompt_embeds, + controlnet_cond=data.control_image, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, return_dict=False, ) # when we apply guidance for unet, but not for controlnet: # add 0 to the unconditional batch - down_block_res_samples = pipeline.guider.prepare_input( - down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] + data.down_block_res_samples = pipeline.guider.prepare_input( + data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] ) - mid_block_res_sample = pipeline.guider.prepare_input( - mid_block_res_sample, torch.zeros_like(mid_block_res_sample) + data.mid_block_res_sample = pipeline.guider.prepare_input( + data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) ) - latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) + if data.num_channels_unet == 9: + data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - noise_pred = pipeline.unet( - latent_model_input, + data.noise_pred = pipeline.unet( + data.latent_model_input, t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + encoder_hidden_states=data.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents) + data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if latents.dtype != latents_dtype: + data.latents_dtype = data.latents.dtype + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + data.latents = data.latents.to(data.latents_dtype) - if num_channels_unet == 4 and mask is not None and image_latents is not None: - init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] - init_latents_proper = image_latents - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = pipeline.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) + if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: + data.init_latents_proper = data.image_latents + if i < len(data.timesteps) - 1: + data.noise_timestep = data.timesteps[i + 1] + data.init_latents_proper = pipeline.scheduler.add_noise( + data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() pipeline.guider.reset_guider(pipeline) pipeline.controlnet_guider.reset_guider(pipeline) - state.add_intermediate("latents", latents) + + self.add_block_state(state, data) return pipeline, state @@ -1608,20 +1526,23 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam("latents", required=True), + InputParam("batch_size", required=True), InputParam("timesteps", required=True), InputParam("num_inference_steps", required=True), InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds", required=True), + InputParam("negative_prompt_embeds"), InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids", required=True), + InputParam("negative_add_time_ids"), InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds", required=True), + InputParam("negative_pooled_prompt_embeds"), InputParam("timestep_cond"), # LCM InputParam("mask"), # inpainting InputParam("masked_image_latents"), # inpainting InputParam("noise"), # inpainting InputParam("image_latents"), # inpainting InputParam("crops_coords"), # inpainting + InputParam("ip_adapter_embeds"), # ip adapter + InputParam("negative_ip_adapter_embeds"), # ip adapter ] @property @@ -1638,234 +1559,246 @@ def __init__(self): control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) self.auxiliaries["control_image_processor"] = control_image_processor + def check_inputs(self, pipeline, data): + + num_channels_unet = pipeline.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if data.mask is None or data.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = data.latents.shape[1] + num_channels_mask = data.mask.shape[1] + num_channels_masked_image = data.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" + f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - guidance_scale = state.get_input("guidance_scale") - guidance_rescale = state.get_input("guidance_rescale") - cross_attention_kwargs = state.get_input("cross_attention_kwargs") - guider_kwargs = state.get_input("guider_kwargs") - generator = state.get_input("generator") - eta = state.get_input("eta") - num_images_per_prompt = state.get_input("num_images_per_prompt") - # controlnet-specific inputs - control_image = state.get_input("control_image") - control_guidance_start = state.get_input("control_guidance_start") - control_guidance_end = state.get_input("control_guidance_end") - controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale") - control_mode = state.get_input("control_mode") - guess_mode = state.get_input("guess_mode") - - latents = state.get_intermediate("latents") - timesteps = state.get_intermediate("timesteps") - num_inference_steps = state.get_intermediate("num_inference_steps") - - prompt_embeds = state.get_intermediate("prompt_embeds") - negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds") - pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds") - negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds") - add_time_ids = state.get_intermediate("add_time_ids") - negative_add_time_ids = state.get_intermediate("negative_add_time_ids") - - timestep_cond = state.get_intermediate("timestep_cond") - - # inpainting - mask = state.get_intermediate("mask") - noise = state.get_intermediate("noise") - image_latents = state.get_intermediate("image_latents") - crops_coords = state.get_intermediate("crops_coords") - - device = pipeline._execution_device - batch_size = latents.shape[0] - height, width = latents.shape[-2:] - height = height * pipeline.vae_scale_factor - width = width * pipeline.vae_scale_factor - - # prepare controlnet inputs - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + data = self.get_block_state(state) + self.check_inputs(pipeline, data) - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] + data.num_channels_unet = pipeline.unet.config.in_channels - global_pool_conditions = controlnet.config.global_pool_conditions - guess_mode = guess_mode or global_pool_conditions + # (1) prepare controlnet inputs + data.device = pipeline._execution_device + data.height, data.width = data.latents.shape[-2:] + data.height = data.height * pipeline.vae_scale_factor + data.width = data.width * pipeline.vae_scale_factor - num_control_type = controlnet.config.num_control_type + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - if not isinstance(control_image, list): - control_image = [control_image] + # (1.1) + # control guidance + if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): + data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] + elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): + data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - if not isinstance(control_mode, list): - control_mode = [control_mode] + # (1.2) + # global_pool_conditions & guess_mode + data.global_pool_conditions = controlnet.config.global_pool_conditions + data.guess_mode = data.guess_mode or data.global_pool_conditions - if len(control_image) != len(control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") + # (1.3) + # control_type + data.num_control_type = controlnet.config.num_control_type - control_type = [0 for _ in range(num_control_type)] - for control_idx in control_mode: - control_type[control_idx] = 1 + # (1.4) + # control_type + if not isinstance(data.control_image, list): + data.control_image = [data.control_image] - control_type = torch.Tensor(control_type) + if not isinstance(data.control_mode, list): + data.control_mode = [data.control_mode] - for idx, _ in enumerate(control_image): - control_image[idx] = pipeline.prepare_control_image( - image=control_image[idx], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, + if len(data.control_image) != len(data.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + data.control_type = [0 for _ in range(data.num_control_type)] + for control_idx in data.control_mode: + data.control_type[control_idx] = 1 + + data.control_type = torch.Tensor(data.control_type) + + # (1.5) + # prepare control_image + for idx, _ in enumerate(data.control_image): + data.control_image[idx] = pipeline.prepare_control_image( + image=data.control_image[idx], + width=data.width, + height=data.height, + batch_size=data.batch_size * data.num_images_per_prompt, + num_images_per_prompt=data.num_images_per_prompt, + device=data.device, dtype=controlnet.dtype, - crops_coords=crops_coords, + crops_coords=data.crops_coords, ) - height, width = control_image[idx].shape[-2:] + data.height, data.width = data.control_image[idx].shape[-2:] - controlnet_keep = [] - for i in range(len(timesteps)): - controlnet_keep.append( + + # (1.6) + # controlnet_keep + data.controlnet_keep = [] + for i in range(len(data.timesteps)): + data.controlnet_keep.append( 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) + - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) ) - # Prepare conditional inputs for unet using the guider + # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - guider_kwargs = guider_kwargs or {} - guider_kwargs = { - **guider_kwargs, - "disable_guidance": disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, + data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + data.guider_kwargs = data.guider_kwargs or {} + data.guider_kwargs = { + **data.guider_kwargs, + "disable_guidance": data.disable_guidance, + "guidance_scale": data.guidance_scale, + "guidance_rescale": data.guidance_rescale, + "batch_size": data.batch_size, } - pipeline.guider.set_guider(pipeline, guider_kwargs) - prompt_embeds = pipeline.guider.prepare_input( - prompt_embeds, - negative_prompt_embeds, + pipeline.guider.set_guider(pipeline, data.guider_kwargs) + data.prompt_embeds = pipeline.guider.prepare_input( + data.prompt_embeds, + data.negative_prompt_embeds, ) - add_time_ids = pipeline.guider.prepare_input( - add_time_ids, - negative_add_time_ids, + data.add_time_ids = pipeline.guider.prepare_input( + data.add_time_ids, + data.negative_add_time_ids, ) - pooled_prompt_embeds = pipeline.guider.prepare_input( - pooled_prompt_embeds, - negative_pooled_prompt_embeds, + data.pooled_prompt_embeds = pipeline.guider.prepare_input( + data.pooled_prompt_embeds, + data.negative_pooled_prompt_embeds, ) - added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds, - "time_ids": add_time_ids, + if data.num_channels_unet == 9: + data.mask = pipeline.guider.prepare_input(data.mask, data.mask) + data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) + + data.added_cond_kwargs = { + "text_embeds": data.pooled_prompt_embeds, + "time_ids": data.add_time_ids, } - # Prepare conditional inputs for controlnet using the guider - controlnet_disable_guidance = True if disable_guidance or guess_mode else False - controlnet_guider_kwargs = guider_kwargs or {} - controlnet_guider_kwargs = { - **controlnet_guider_kwargs, - "disable_guidance": controlnet_disable_guidance, - "guidance_scale": guidance_scale, - "guidance_rescale": guidance_rescale, - "batch_size": batch_size, + if data.ip_adapter_embeds is not None: + data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) + data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + + # (3) Prepare conditional inputs for controlnet using the guider + data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False + data.controlnet_guider_kwargs = data.guider_kwargs or {} + data.controlnet_guider_kwargs = { + **data.controlnet_guider_kwargs, + "disable_guidance": data.controlnet_disable_guidance, + "guidance_scale": data.guidance_scale, + "guidance_rescale": data.guidance_rescale, + "batch_size": data.batch_size, } - pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs) - controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds) - controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids), + pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) + data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) + data.controlnet_added_cond_kwargs = { + "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), + "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), } - for idx, _ in enumerate(control_image): - control_image[idx] = pipeline.controlnet_guider.prepare_input(control_image[idx], control_image[idx]) - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + for idx, _ in enumerate(data.control_image): + data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) - control_type = ( - control_type.reshape(1, -1) - .to(device, dtype=prompt_embeds.dtype) + data.control_type = ( + data.control_type.reshape(1, -1) + .to(data.device, dtype=data.prompt_embeds.dtype) ) - control_type = pipeline.controlnet_guider.prepare_input(control_type, control_type) + data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) - with pipeline.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) + data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + + + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + for i, t in enumerate(data.timesteps): # prepare latents for unet using the guider - latent_model_input = pipeline.guider.prepare_input(latents, latents) + data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) # prepare latents for controlnet using the guider - control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents) + data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + if isinstance(data.controlnet_keep[i], list): + data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] + data.controlnet_cond_scale = data.controlnet_conditioning_scale + if isinstance(data.controlnet_cond_scale, list): + data.controlnet_cond_scale = data.controlnet_cond_scale[0] + data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(control_model_input, t), + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + pipeline.scheduler.scale_model_input(data.control_model_input, t), t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image, - control_type=control_type, - control_type_idx=control_mode, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, + encoder_hidden_states=data.controlnet_prompt_embeds, + controlnet_cond=data.control_image, + control_type=data.control_type, + control_type_idx=data.control_mode, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, return_dict=False, ) # when we apply guidance for unet, but not for controlnet: # add 0 to the unconditional batch - down_block_res_samples = pipeline.guider.prepare_input( - down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples] + data.down_block_res_samples = pipeline.guider.prepare_input( + data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] ) - mid_block_res_sample = pipeline.guider.prepare_input( - mid_block_res_sample, torch.zeros_like(mid_block_res_sample) + data.mid_block_res_sample = pipeline.guider.prepare_input( + data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) ) - latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) + if data.num_channels_unet == 9: + data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - noise_pred = pipeline.unet( - latent_model_input, + data.noise_pred = pipeline.unet( + data.latent_model_input, t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + encoder_hidden_states=data.prompt_embeds, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, return_dict=False, )[0] # perform guidance - noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents) + data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if latents.dtype != latents_dtype: + data.latents_dtype = data.latents.dtype + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if mask is not None and image_latents is not None: - init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0] - init_latents_proper = image_latents - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = pipeline.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) + data.latents = data.latents.to(data.latents_dtype) + + if data.num_channels_unet == 9 and data.mask is not None and data.image_latents is not None: + data.init_latents_proper = data.image_latents + if i < len(data.timesteps) - 1: + data.noise_timestep = data.timesteps[i + 1] + data.init_latents_proper = pipeline.scheduler.add_noise( + data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() pipeline.guider.reset_guider(pipeline) pipeline.controlnet_guider.reset_guider(pipeline) - state.add_intermediate("latents", latents) + + self.add_block_state(state, data) return pipeline, state @@ -1895,56 +1828,54 @@ def __init__(self): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - output_type = state.get_input("output_type") - return_dict = state.get_input("return_dict") - - latents = state.get_intermediate("latents") + data = self.get_block_state(state) - if not output_type == "latent": + if not data.output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast + data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast - if needs_upcasting: + if data.needs_upcasting: pipeline.upcast_vae() - latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif latents.dtype != pipeline.vae.dtype: + data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) + elif data.latents.dtype != pipeline.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(latents.dtype) + pipeline.vae = pipeline.vae.to(data.latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None - has_latents_mean = ( + data.has_latents_mean = ( hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None ) - has_latents_std = ( + data.has_latents_std = ( hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None ) - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + if data.has_latents_mean and data.has_latents_std: + data.latents_mean = ( + torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + data.latents_std = ( + torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) ) - latents = latents * latents_std / pipeline.vae.config.scaling_factor + latents_mean + data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean else: - latents = latents / pipeline.vae.config.scaling_factor + data.latents = data.latents / pipeline.vae.config.scaling_factor - image = pipeline.vae.decode(latents, return_dict=False)[0] + data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] # cast back to fp16 if needed - if needs_upcasting: + if data.needs_upcasting: pipeline.vae.to(dtype=torch.float16) else: - image = latents + data.images = data.latents # apply watermark if available if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - image = pipeline.watermark.apply_watermark(image) + data.images = pipeline.watermark.apply_watermark(data.images) + + data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) - image = pipeline.image_processor.postprocess(image, output_type=output_type) - state.add_intermediate("images", image) + self.add_block_state(state, data) return pipeline, state @@ -1970,16 +1901,12 @@ def intermediates_outputs(self) -> List[str]: @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - original_image = state.get_input("image") - padding_mask_crop = state.get_input("padding_mask_crop") - mask_image = state.get_input("mask_image") - images = state.get_intermediate("images") - crops_coords = state.get_intermediate("crops_coords") + data = self.get_block_state(state) - if padding_mask_crop is not None and crops_coords is not None: - images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images] + if data.padding_mask_crop is not None and data.crops_coords is not None: + data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images] - state.add_intermediate("images", images) + self.add_block_state(state, data) return pipeline, state @@ -2005,14 +1932,13 @@ def outputs(self) -> List[Tuple[str, Any]]: @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - images = state.get_intermediate("images") - return_dict = state.get_input("return_dict") + data = self.get_block_state(state) - if not return_dict: - output = (images,) + if not data.return_dict: + data.images = (data.images,) else: - output = StableDiffusionXLPipelineOutput(images=images) - state.add_output("images", output) + data.images = StableDiffusionXLPipelineOutput(images=data.images) + self.add_block_state(state, data) return pipeline, state @@ -2076,8 +2002,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): IMAGE2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), @@ -2087,8 +2013,8 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): INPAINT_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), - ("input", StableDiffusionXLInputStep), ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), @@ -2100,6 +2026,10 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), +]) + AUTO_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), @@ -2114,6 +2044,7 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): "img2img": IMAGE2IMAGE_BLOCKS, "inpaint": INPAINT_BLOCKS, "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, "auto": AUTO_BLOCKS } @@ -2123,6 +2054,7 @@ class StableDiffusionXLModularPipeline( StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, ): @property def default_sample_size(self): From d046cf7d35ae0714f24213aab0c7e4d0030407fe Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 22 Jan 2025 09:48:57 +0100 Subject: [PATCH 052/170] block state + fix for num_images_per_prompt > 1 for denoise/controlnet union etc --- src/diffusers/pipelines/modular_pipeline.py | 49 +++++++++++++++++-- .../pipeline_stable_diffusion_xl_modular.py | 17 ++++--- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 5ead9dbbe5e4..aad473aecea3 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -18,7 +18,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union -from types import SimpleNamespace import torch from tqdm.auto import tqdm @@ -103,6 +102,50 @@ def format_value(v): f")" ) + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __repr__(self): + def format_value(v): + # Handle tensors directly + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + + # Handle lists of tensors + elif isinstance(v, list): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"List[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle tuples of tensors + elif isinstance(v, tuple): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle dicts with tensor values + elif isinstance(v, dict): + if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): + shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} + return f"Dict of Tensors with shapes {shapes}" + return repr(v) + + # Default case + return repr(v) + + attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) + return f"BlockState(\n{attributes}\n)" + + @dataclass class InputParam: name: str @@ -487,9 +530,9 @@ def get_block_state(self, state: PipelineState) -> dict: raise ValueError(f"Required intermediate input '{input_param.name}' is missing") data[input_param.name] = value - return SimpleNamespace(**data) + return BlockState(**data) - def add_block_state(self, state: PipelineState, block_state: SimpleNamespace): + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index ab171e365199..76ca184b9064 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -988,8 +988,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin data.pooled_prompt_embeds.dtype, text_encoder_projection_dim=data.text_encoder_projection_dim, ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - if data.negative_original_size is not None and data.negative_target_size is not None: data.negative_add_time_ids = pipeline._get_add_time_ids( data.negative_original_size, @@ -1000,6 +998,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin ) else: data.negative_add_time_ids = data.add_time_ids + + data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) # Optionally get Guidance Scale Embedding for LCM @@ -1031,6 +1031,7 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("generator", default=None), InputParam("eta", default=0.0), InputParam("guider_kwargs", default=None), + InputParam("num_images_per_prompt", default=1), ] @property @@ -1101,7 +1102,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "disable_guidance": data.disable_guidance, "guidance_scale": data.guidance_scale, "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size, + "batch_size": data.batch_size * data.num_images_per_prompt, } pipeline.guider.set_guider(pipeline, data.guider_kwargs) @@ -1366,7 +1367,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "disable_guidance": data.disable_guidance, "guidance_scale": data.guidance_scale, "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size, + "batch_size": data.batch_size * data.num_images_per_prompt, } pipeline.guider.set_guider(pipeline, data.guider_kwargs) data.prompt_embeds = pipeline.guider.prepare_input( @@ -1402,7 +1403,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "disable_guidance": data.controlnet_disable_guidance, "guidance_scale": data.guidance_scale, "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size, + "batch_size": data.batch_size * data.num_images_per_prompt, } pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) @@ -1660,7 +1661,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "disable_guidance": data.disable_guidance, "guidance_scale": data.guidance_scale, "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size, + "batch_size": data.batch_size * data.num_images_per_prompt, } pipeline.guider.set_guider(pipeline, data.guider_kwargs) data.prompt_embeds = pipeline.guider.prepare_input( @@ -1697,7 +1698,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: "disable_guidance": data.controlnet_disable_guidance, "guidance_scale": data.guidance_scale, "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size, + "batch_size": data.batch_size * data.num_images_per_prompt, } pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) @@ -1712,6 +1713,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.control_type.reshape(1, -1) .to(data.device, dtype=data.prompt_embeds.dtype) ) + repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] + data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline From 71df1581f74cc10896e70352463c5b32d077012c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 22 Jan 2025 06:19:22 -1000 Subject: [PATCH 053/170] Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index ab171e365199..56360c85ba50 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -1338,7 +1338,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: num_images_per_prompt=data.num_images_per_prompt, device=data.device, dtype=controlnet.dtype, - crops_coords=crops_coords, + crops_coords=data.crops_coords, ) control_images.append(control_image) From 00cae4e857d1ffbe71ea80aec8641954e3eaccc1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 23 Jan 2025 11:07:13 +0100 Subject: [PATCH 054/170] docstring doc doc doc --- src/diffusers/pipelines/modular_pipeline.py | 248 +-- .../pipeline_stable_diffusion_xl_modular.py | 1761 +++++++++++++---- 2 files changed, 1452 insertions(+), 557 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index aad473aecea3..a4c6baad47f5 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -228,27 +228,26 @@ def format_intermediates_short(intermediates_inputs: List[InputParam], required_ output_parts.extend(outputs) # Combine with arrow notation if both inputs and outputs exist - if input_parts and output_parts: - return f"{', '.join(input_parts)} -> {', '.join(output_parts)}" + if output_parts: + return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}" elif input_parts: return ', '.join(input_parts) - elif output_parts: - return ', '.join(output_parts) return "" -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam objects into a readable string representation. +def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: + """Format a list of InputParam or OutputParam objects into a readable string representation. Args: - input_params: List of InputParam objects to format + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") indent_level: Number of spaces to indent each parameter line (default: 4) max_line_length: Maximum length for each line before wrapping (default: 115) Returns: - A formatted string representing all input parameters + A formatted string representing all parameters """ - if not input_params: + if not params: return "" base_indent = " " * indent_level @@ -270,10 +269,8 @@ def wrap_text(text: str, indent: str, max_length: int) -> str: current_length = 0 for word in words: - # Calculate word length including space word_length = len(word) + (1 if current_line else 0) - # Check if adding this word would exceed the max length if current_line and current_length + word_length > max_length: lines.append(" ".join(current_line)) current_line = [word] @@ -285,22 +282,22 @@ def wrap_text(text: str, indent: str, max_length: int) -> str: if current_line: lines.append(" ".join(current_line)) - # Join lines with proper indentation return f"\n{indent}".join(lines) - # Add the "Args:" header - formatted_params.append(f"{base_indent}Args:") + # Add the header + formatted_params.append(f"{base_indent}{header}:") - for param in input_params: + for param in params: # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" param_str = f"{param_indent}{param.name} (`{type_str}`" - # Add optional tag and default value if parameter is optional - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" + # Add optional tag and default value if parameter is an InputParam and optional + if isinstance(param, InputParam): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" param_str += "):" # Add description on a new line with additional indentation and wrapping @@ -317,76 +314,49 @@ def wrap_text(text: str, indent: str, max_length: int) -> str: return "\n\n".join(formatted_params) +# Then update the original functions to use this combined version: +def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: + return format_params(input_params, "Args", indent_level, max_line_length) def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of OutputParam objects into a readable string representation. + return format_params(output_params, "Returns", indent_level, max_line_length) - Args: - output_params: List of OutputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - Returns: - A formatted string representing all output parameters + +def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""): """ - if not output_params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] + Generates a formatted documentation string describing the pipeline block's parameters and structure. - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 + Returns: + str: A formatted string containing information about call parameters, intermediate inputs/outputs, + and final intermediate outputs. + """ + output = "" - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the "Returns:" header - formatted_params.append(f"{base_indent}Returns:") + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + output += format_input_params(inputs + intermediates_inputs, indent_level=2) - for param in output_params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) + # YiYi TODO: refactor to remove this and `outputs` attribute instead + if final_intermediates_outputs: + output += "\n\n" + output += format_output_params(final_intermediates_outputs, indent_level=2) + + if intermediates_outputs: + output += "\n\n------------------------\n" + intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2) + output += intermediates_str - return "\n\n".join(formatted_params) + elif intermediates_outputs: + output +="\n\n" + output += format_output_params(intermediates_outputs, indent_level=2) + + + return output + class PipelineBlock: # YiYi Notes: do we need this? @@ -394,7 +364,11 @@ class PipelineBlock: expected_components = [] expected_configs = [] model_name = None - + + @property + def description(self) -> str: + return "" + @property def inputs(self) -> List[InputParam]: return [] @@ -472,7 +446,7 @@ def __repr__(self): # Intermediates section intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n {intermediates_str}" + intermediates = f"Intermediates(`*` = modified):\n {intermediates_str}" return ( f"{class_name}(\n" @@ -484,33 +458,11 @@ def __repr__(self): f")" ) - def get_doc_string(self): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "Call Parameters:\n" - output += "------------------------\n" - output += format_input_params(self.inputs, indent_level=2) - - output += "\n\nIntermediate inputs:\n" - output += "--------------------------\n" - output += format_input_params(self.intermediates_inputs, indent_level=2) - if hasattr(self, "intermediates_outputs"): - output += "\n\nIntermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.intermediates_outputs, indent_level=2) - - if hasattr(self, "final_intermediates_outputs"): - output += "\nFinal intermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.final_intermediates_outputs, indent_level=2) + @property + def doc(self): + return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description) - return output def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" @@ -643,6 +595,10 @@ def __init__(self): @property def model_name(self): return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" @property def expected_components(self): @@ -849,7 +805,7 @@ def __repr__(self): inputs_str = format_inputs_short(block.inputs) sections.append(f" inputs:\n {inputs_str}") - intermediates_str = f" intermediates:\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}" + intermediates_str = f" intermediates(`*` = modified):\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}" sections.append(intermediates_str) sections.append("") @@ -861,33 +817,9 @@ def __repr__(self): f")" ) - def get_doc_string(self): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "Call Parameters:\n" - output += "------------------------\n" - output += format_input_params(self.inputs, indent_level=2) - - output += "\n\nIntermediate inputs:\n" - output += "--------------------------\n" - output += format_input_params(self.intermediates_inputs, indent_level=2) - - if hasattr(self, "intermediates_outputs"): - output += "\n\nIntermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.intermediates_outputs, indent_level=2) - - if hasattr(self, "final_intermediates_outputs"): - output += "\nFinal intermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.final_intermediates_outputs, indent_level=2) - - return output + @property + def doc(self): + return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description) class SequentialPipelineBlocks: """ @@ -899,6 +831,10 @@ class SequentialPipelineBlocks: @property def model_name(self): return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" @property def expected_components(self): @@ -1192,7 +1128,7 @@ def __repr__(self): intermediates_str = format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs) if intermediates_str: - blocks_str += f" intermediates: {intermediates_str}\n" + blocks_str += f" intermediates(`*` = modified): {intermediates_str}\n" blocks_str += "\n" inputs_str = format_inputs_short(self.inputs) @@ -1220,33 +1156,9 @@ def __repr__(self): f")" ) - def get_doc_string(self): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "Call Parameters:\n" - output += "------------------------\n" - output += format_input_params(self.inputs, indent_level=2) - - output += "\n\nIntermediate inputs:\n" - output += "--------------------------\n" - output += format_input_params(self.intermediates_inputs, indent_level=2) - - if hasattr(self, "intermediates_outputs"): - output += "\n\nIntermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.intermediates_outputs, indent_level=2) - - if hasattr(self, "final_intermediates_outputs"): - output += "\nFinal intermediate outputs:\n" - output += "--------------------------\n" - output += format_output_params(self.final_intermediates_outputs, indent_level=2) - - return output + @property + def doc(self): + return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, self.final_intermediates_outputs, self.description) class ModularPipeline(ConfigMixin): """ @@ -1467,8 +1379,9 @@ def default_call_parameters(self) -> Dict[str, Any]: def __repr__(self): output = "ModularPipeline:\n" output += "==============================\n\n" - + block = self.pipeline_block + if hasattr(block, "trigger_inputs") and block.trigger_inputs: output += "\n" output += " Trigger Inputs:\n" @@ -1514,7 +1427,10 @@ def __repr__(self): output += "\n" # List the call parameters - output += self.pipeline_block.get_doc_string() + full_doc = self.pipeline_block.doc + if "------------------------" in full_doc: + full_doc = full_doc.split("------------------------")[0].rstrip() + output += full_doc return output diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 47b0e454a8c0..f42c2359a426 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -13,14 +13,14 @@ # limitations under the License. import inspect -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Dict import PIL import torch from collections import OrderedDict from ...guider import CFGGuider -from ...image_processor import VaeImageProcessor +from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...models import ControlNetModel, ImageProjection from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor @@ -47,6 +47,7 @@ StableDiffusionXLPipelineOutput, ) +import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -126,114 +127,92 @@ def retrieve_latents( -class StableDiffusionXLInputStep(PipelineBlock): +class StableDiffusionXLLoraStep(PipelineBlock): + expected_components = ["text_encoder", "text_encoder_2", "unet"] + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc" + " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" + " for more details" + ) + + def __init__(self): + super().__init__() + self.components["text_encoder"] = None + self.components["text_encoder_2"] = None + self.components["unet"] = None + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + expected_components = ["image_encoder", "feature_extractor", "unet"] model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + @property def inputs(self) -> List[InputParam]: return [ InputParam( - name="num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt.", + "ip_adapter_image", + required=True, + type_hint=PipelineImageInput, + description="The image(s) to be used as ip adapter" + ), + InputParam( + "guidance_scale", + default=5.0, + description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`." ), ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds"), - InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds"), - InputParam("ip_adapter_embeds"), - InputParam("negative_ip_adapter_embeds"), - ] @property def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size"), - OutputParam("dtype"), - OutputParam("prompt_embeds"), - OutputParam("negative_prompt_embeds"), - OutputParam("pooled_prompt_embeds"), - OutputParam("negative_pooled_prompt_embeds"), - OutputParam("ip_adapter_embeds"), - OutputParam("negative_ip_adapter_embeds"), - ] + return [OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")] + + def __init__(self): + super().__init__() + self.components["image_encoder"] = None + self.components["feature_extractor"] = None + self.components["unet"] = None - def check_inputs(self, pipeline, data): - - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." - ) - - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." - ) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) - self.check_inputs(pipeline, data) - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype + data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.device = pipeline._execution_device - _, seq_len, _ = data.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) + data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( + ip_adapter_image=data.ip_adapter_image, + ip_adapter_image_embeds=None, + device=data.device, + num_images_per_prompt=1, + do_classifier_free_guidance=data.do_classifier_free_guidance, + ) + if data.do_classifier_free_guidance: + data.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(data.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + data.negative_ip_adapter_embeds.append(negative_image_embeds) + data.ip_adapter_embeds[i] = image_embeds - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) - self.add_block_state(state, data) - return pipeline, state @@ -242,13 +221,20 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): expected_configs = ["force_zeros_for_empty_prompt"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return( + "Text Encoder step that generate text_embeddings to guide the image generation" + ) + + @property def inputs(self) -> List[InputParam]: return [ InputParam( name="prompt", type_hint=Union[str, List[str]], - description="The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.", + description="The prompt or prompts to guide the image generation.", ), InputParam( name="prompt_2", @@ -258,7 +244,7 @@ def inputs(self) -> List[InputParam]: InputParam( name="negative_prompt", type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", + description="The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", ), InputParam( name="negative_prompt_2", @@ -286,10 +272,10 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[str]: return [ - OutputParam("prompt_embeds"), - OutputParam("negative_prompt_embeds"), - OutputParam("pooled_prompt_embeds"), - OutputParam("negative_pooled_prompt_embeds"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), ] def __init__(self): @@ -351,24 +337,55 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): expected_components = ["vae"] model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Vae Encoder step that encode the input image into a latent representation" + ) + @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="image", required=True), - InputParam(name="generator"), - InputParam(name="height"), - InputParam(name="width"), + InputParam( + name="image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." + ), + InputParam( + name="generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)" + "to make generation deterministic." + ), + InputParam( + name="height", + type_hint=Optional[int], + description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions.", + ), + InputParam( + name="width", + type_hint=Optional[int], + description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions.", + ), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("dtype"), - InputParam("preprocess_kwargs")] + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("image_latents")] + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] def __init__(self): super().__init__() @@ -402,92 +419,323 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLLoraStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "unet"] +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + expected_components = ["vae"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Vae encoder step that prepares the image and mask for the inpainting process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "height", + type_hint=Optional[int], + description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions.", + ), + InputParam( + "width", + type_hint=Optional[int], + description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions.", + ), + InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " + "to make generation deterministic." + ), + InputParam( + "image", + required=True, + type_hint=PipelineImageInput, + description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." + ), + InputParam( + "mask_image", + required=True, + type_hint=PipelineImageInput, + description="`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be " + "repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted " + "to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) " + "instead of 3, so the expected shape would be `(B, H, W, 1)`." + ), + InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to " + "image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region " + "with the same aspect ratio of the image and contains all masked area, and then expand that area based " + "on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before " + "resizing to the original image size for inpainting. This is useful when the masked area is small while " + "the image is large and contain information irrelevant for inpainting, such as background." + ), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + def __init__(self): super().__init__() - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["unet"] = None - + self.auxiliaries["image_processor"] = VaeImageProcessor() + self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) + self.components["vae"] = None + @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") + def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + data = self.get_block_state(state) -class StableDiffusionXLIPAdapterStep(PipelineBlock): - expected_components = ["image_encoder", "feature_extractor", "unet"] + data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + data.device = pipeline._execution_device + + if data.padding_mask_crop is not None: + data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) + data.resize_mode = "fill" + else: + data.crops_coords = None + data.resize_mode = "default" + + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) + data.image = data.image.to(dtype=torch.float32) + + data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) + data.masked_image = data.image * (data.mask < 0.5) + + data.batch_size = data.image.shape[0] + data.image = data.image.to(device=data.device, dtype=data.dtype) + data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) + + # 7. Prepare mask latent variables + data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( + data.mask, + data.masked_image, + data.batch_size, + data.height, + data.width, + data.dtype, + data.device, + data.generator, + ) + + self.add_block_state(state, data) + + + return pipeline, state + + +class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + @property def inputs(self) -> List[InputParam]: return [ - InputParam("ip_adapter_image", required=True), - InputParam("guidance_scale", default=5.0), + InputParam( + name="num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt.", + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("ip_adapter_embeds"), OutputParam("negative_ip_adapter_embeds")] - - def __init__(self): - super().__init__() - self.components["image_encoder"] = None - self.components["feature_extractor"] = None - self.components["unet"] = None + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), + ] + def check_inputs(self, pipeline, data): + + if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: + if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {data.negative_prompt_embeds.shape}." + ) + + if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): + if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {data.negative_ip_adapter_embeds[i].shape}." + ) @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) + self.check_inputs(pipeline, data) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device + data.batch_size = data.prompt_embeds.shape[0] + data.dtype = data.prompt_embeds.dtype - data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - do_classifier_free_guidance=data.do_classifier_free_guidance, - ) - if data.do_classifier_free_guidance: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds + _, seq_len, _ = data.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + + if data.negative_prompt_embeds is not None: + _, seq_len, _ = data.negative_prompt_embeds.shape + data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + + data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + + if data.negative_pooled_prompt_embeds is not None: + data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) + data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + + if data.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): + data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) + if data.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): + data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) + self.add_block_state(state, data) - return pipeline, state + return pipeline, state class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation." + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - InputParam("strength", default=0.3), - InputParam("denoising_start"), - InputParam("num_images_per_prompt", default=1), + InputParam( + "num_inference_steps", + default=50, + type_hint=int, + description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the" + " expense of slower inference." + ), + InputParam( + "timesteps", + type_hint=Optional[torch.Tensor], + description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." + ), + InputParam( + "sigmas", + type_hint=Optional[torch.Tensor], + description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." + ), + InputParam( + "denoising_end", + type_hint=Optional[float], + description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." + ), + InputParam( + "strength", + default=0.3, + type_hint=float, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + InputParam( + "denoising_start", + type_hint=Optional[float], + description="The denoising start value to use for the scheduler. Determines the starting point of the denoising process." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt. Defaults to 1." + ), ] @property def intermediates_inputs(self) -> List[str]: - return [InputParam("batch_size", required=True)] + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("timesteps"), OutputParam("num_inference_steps"), OutputParam("latent_timestep")] + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] def __init__(self): super().__init__() @@ -532,19 +780,43 @@ def denoising_value_valid(dnv): class StableDiffusionXLSetTimestepsStep(PipelineBlock): expected_components = ["scheduler"] model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that sets the scheduler's timesteps for inference" + ) @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), + InputParam( + "num_inference_steps", + default=50, + type_hint=int, + description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference." + ), + InputParam( + "timesteps", + type_hint=Optional[torch.Tensor], + description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." + ), + InputParam( + "sigmas", + type_hint=Optional[torch.Tensor], + description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." + ), + InputParam( + "denoising_end", + type_hint=Optional[float], + description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." + ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps"), OutputParam("num_inference_steps")] + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] def __init__(self): super().__init__() @@ -574,105 +846,98 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + expected_components = ["scheduler"] model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents"), OutputParam("mask"), OutputParam("masked_image_latents"), OutputParam("crops_coords")] - - def __init__(self): - super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - self.components["vae"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" - else: - data.crops_coords = None - data.resize_mode = "default" - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) - - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) - - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" ) - self.add_block_state(state, data) - - - return pipeline, state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - InputParam("strength", default=0.9999), + InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " + "to make generation deterministic."), + InputParam( + "latents", + type_hint=Optional[torch.Tensor], + description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt" + ), + InputParam( + "denoising_start", + type_hint=Optional[float], + description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." + ), + InputParam( + "strength", + default=0.9999, + type_hint=float, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", required=True), - InputParam("latent_timestep", required=True), - InputParam("image_latents", required=True), - InputParam("mask", required=True), - InputParam("masked_image_latents"), # only for inpainting-specific unet - InputParam("dtype")] + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents"), OutputParam("mask"), OutputParam("masked_image_latents"), OutputParam("noise")] + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] def __init__(self): super().__init__() @@ -735,26 +1000,50 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): expected_components = ["vae", "scheduler"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), + InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " + "to make generation deterministic." + ), + InputParam( + "latents", + type_hint=Optional[torch.Tensor], + description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt" + ), + InputParam( + "denoising_start", + type_hint=Optional[float], + description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." + ), ] @property - def intermediates_inputs(self) -> List[str]: + def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latent_timestep", required=True), - InputParam("image_latents", required=True), - InputParam("batch_size", required=True), - InputParam("dtype")] + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents")] + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] def __init__(self): super().__init__() @@ -789,22 +1078,72 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @property - def inputs(self) -> List[Tuple[str, Any]]: + def description(self) -> str: + return ( + "Prepare latents step that prepares the latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), + InputParam( + "height", + type_hint=Optional[int], + description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions."), + InputParam( + "width", + type_hint=Optional[int], + description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " + "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" + "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " + "specifically fine-tuned on low resolutions."), + InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " + "to make generation deterministic." + ), + InputParam( + "latents", + type_hint=Optional[torch.Tensor], + description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt" + ), ] @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("batch_size", required=True), InputParam("dtype")] + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents")] + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] def __init__(self): super().__init__() @@ -856,32 +1195,99 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): expected_configs = ["requires_aesthetics_score"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("aesthetic_score", default=6.0), - InputParam("negative_aesthetic_score", default=2.0), + InputParam( + "original_size", + type_hint=Optional[Tuple[int]], + description="If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. " + "`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as " + "explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "target_size", + type_hint=Optional[Tuple[int]], + description="For most cases, `target_size` should be set to the desired height and width of the generated image. If " + "not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in " + "section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "negative_original_size", + type_hint=Optional[Tuple[int]], + description="To negatively condition the generation process based on a specific image resolution. Part of SDXL's " + "micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "negative_target_size", + type_hint=Optional[Tuple[int]], + description="To negatively condition the generation process based on a target image resolution. It should be as same " + "as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of " + "https://huggingface.co/papers/2307.01952" + ), + InputParam( + "crops_coords_top_left", + default=(0, 0), + type_hint=Tuple[int], + description="`crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position " + "`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning" + ), + InputParam( + "negative_crops_coords_top_left", + default=(0, 0), + type_hint=Tuple[int], + description="To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's " + "micro-conditioning" + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt." + ), + InputParam( + "guidance_scale", + default=5.0, + type_hint=float, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " + "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " + "usually at the expense of lower image quality." + ), + InputParam( + "aesthetic_score", + default=6.0, + type_hint=float, + description="Used to simulate an aesthetic score of the generated image by influencing the positive text condition. " + "Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "negative_aesthetic_score", + default=2.0, + type_hint=float, + description="Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. " + "Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition." + ), ] @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True), - InputParam("pooled_prompt_embeds", required=True), - InputParam("batch_size", required=True), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids"), OutputParam("negative_add_time_ids"), OutputParam("timestep_cond")] + return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] def __init__(self): super().__init__() @@ -942,30 +1348,94 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), + InputParam( + "original_size", + type_hint=Tuple[int, int], + default=(1024, 1024), + description="The original size (height, width) of the image that conditions the generation process. If different from target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "target_size", + type_hint=Tuple[int, int], + default=(1024, 1024), + description="The target size (height, width) of the generated image. For most cases, this should be set to the desired output dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "negative_original_size", + type_hint=Tuple[int, int], + default=(1024, 1024), + description="The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" + ), + InputParam( + "negative_target_size", + type_hint=Tuple[int, int], + default=(1024, 1024), + description="The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" + ), + InputParam( + "crops_coords_top_left", + default=(0, 0), + type_hint=Tuple[int, int], + description="The top-left coordinates (x, y) used to condition the generation process. Setting this to (0, 0) typically produces well-centered images. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" + ), + InputParam( + "negative_crops_coords_top_left", + default=(0, 0), + type_hint=Tuple[int, int], + description="The top-left coordinates (x, y) used to negatively condition the generation process. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. For more information, see: https://github.com/huggingface/diffusers/issues/4208" + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt" + ), + InputParam( + "guidance_scale", + default=5.0, + type_hint=float, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " + "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " + "usually at the expense of lower image quality."), ] @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True), - InputParam("pooled_prompt_embeds", required=True), - InputParam("batch_size", required=True), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids"), OutputParam("negative_add_time_ids"), OutputParam("timestep_cond")] + return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -1022,43 +1492,158 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): expected_components = ["unet", "scheduler", "guider"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), - InputParam("cross_attention_kwargs", default=None), - InputParam("generator", default=None), - InputParam("eta", default=0.0), - InputParam("guider_kwargs", default=None), - InputParam("num_images_per_prompt", default=1), + InputParam( + "guidance_scale", + type_hint=float, + default=5.0, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." + ), + InputParam( + "guidance_rescale", + type_hint=float, + default=0.0, + description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." + ), + InputParam( + "cross_attention_kwargs", + type_hint=Optional[Dict[str, Any]], + default=None, + description="Optional kwargs dictionary passed to the AttentionProcessor." + ), + InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of torch generator(s) to make generation deterministic." + ), + InputParam( + "eta", + type_hint=float, + default=0.0, + description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." + ), + InputParam( + "guider_kwargs", + type_hint=Optional[Dict[str, Any]], + default=None, + description="Optional kwargs dictionary passed to the Guider." + ), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt." + ), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("latents", required=True), - InputParam("batch_size", required=True), - InputParam("timesteps", required=True), - InputParam("num_inference_steps", required=True), - InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds"), - InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids"), - InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds"), - InputParam("timestep_cond"), # LCM - InputParam("mask"), # inpainting - InputParam("masked_image_latents"), # inpainting - InputParam("noise"), # inpainting - InputParam("image_latents"), # inpainting - InputParam("ip_adapter_embeds"), # ip adapter - InputParam("negative_ip_adapter_embeds"), # ip adapter + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " + ), + InputParam( + "add_time_ids", + required=True, + type_hint=torch.Tensor, + description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." + ), + InputParam( + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], + description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents")] + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] def __init__(self): super().__init__() @@ -1194,49 +1779,192 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), - InputParam("cross_attention_kwargs", default=None), - InputParam("generator", default=None), - InputParam("eta", default=0.0), - InputParam("guider_kwargs", default=None), + InputParam( + "control_image", + required=True, + type_hint=PipelineImageInput, + description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching." + ), + InputParam( + "control_guidance_start", + default=0.0, + type_hint=Union[float, List[float]], + description="The percentage of total steps at which the ControlNet starts applying." + ), + InputParam( + "control_guidance_end", + default=1.0, + type_hint=Union[float, List[float]], + description="The percentage of total steps at which the ControlNet stops applying." + ), + InputParam( + "controlnet_conditioning_scale", + default=1.0, + type_hint=Union[float, List[float]], + description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." + ), + InputParam( + "guess_mode", + default=False, + type_hint=bool, + description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt." + ), + InputParam( + "guidance_scale", + default=5.0, + type_hint=float, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." + ), + InputParam( + "guidance_rescale", + default=0.0, + type_hint=float, + description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." + ), + InputParam( + "cross_attention_kwargs", + default=None, + type_hint=Optional[Dict[str, Any]], + description="Optional kwargs dictionary passed to the AttentionProcessor." + ), + InputParam( + "generator", + default=None, + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of torch generator(s) to make generation deterministic." + ), + InputParam( + "eta", + default=0.0, + type_hint=float, + description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." + ), + InputParam( + "guider_kwargs", + default=None, + type_hint=Optional[Dict[str, Any]], + description="Optional kwargs dictionary passed to the Guider." + ), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("latents", required=True), - InputParam("batch_size", required=True), - InputParam("timesteps", required=True), - InputParam("num_inference_steps", required=True), - InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds"), - InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids"), - InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds"), - InputParam("timestep_cond"), # LCM - InputParam("mask"), # inpainting - InputParam("masked_image_latents"), # inpainting - InputParam("noise"), # inpainting - InputParam("image_latents"), # inpainting - InputParam("crops_coords"), # inpainting - InputParam("ip_adapter_embeds"), # ip adapter - InputParam("negative_ip_adapter_embeds"), # ip adapter + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "add_time_ids", + required=True, + type_hint=torch.Tensor, + description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], + description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + InputParam( + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents")] + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] def __init__(self): super().__init__() @@ -1505,50 +2233,188 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return " The denoising step for the controlnet union model, works for inpainting, image-to-image, and text-to-image tasks" @property def inputs(self) -> List[Tuple[str, Any]]: return [ - (InputParam("control_image", required=True)), - (InputParam("control_guidance_start", default=0.0)), - (InputParam("control_guidance_end", default=1.0)), - (InputParam("controlnet_conditioning_scale", default=1.0)), - (InputParam("control_mode", required=True)), - (InputParam("guess_mode", default=False)), - (InputParam("num_images_per_prompt", default=1)), - (InputParam("guidance_scale", default=5.0)), - (InputParam("guidance_rescale", default=0.0)), - (InputParam("cross_attention_kwargs")), - (InputParam("generator")), - (InputParam("eta", default=0.0)), - (InputParam("guider_kwargs")), + InputParam( + "control_image", + required=True, + type_hint=PipelineImageInput, + description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching."), + InputParam( + "control_guidance_start", + default=0.0, + type_hint=Union[float, List[float]], + description="The percentage of total steps at which the ControlNet starts applying."), + InputParam( + "control_guidance_end", + default=1.0, + type_hint=Union[float, List[float]], + description="The percentage of total steps at which the ControlNet stops applying."), + InputParam( + "control_mode", + required=True, + type_hint=List[int], + description="The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment" + ), + InputParam( + "controlnet_conditioning_scale", + default=1.0, + type_hint=Union[float, List[float]], + description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." + ), + InputParam( + "guess_mode", + default=False, + type_hint=bool, + description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." + ), + InputParam( + "num_images_per_prompt", + default=1, + type_hint=int, + description="The number of images to generate per prompt." + ), + InputParam( + "guidance_scale", + default=5.0, + type_hint=float, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1."), + InputParam( + "guidance_rescale", + default=0.0, + type_hint=float, + description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."), + InputParam( + "cross_attention_kwargs", + default=None, + type_hint=Optional[Dict[str, Any]], + description="Optional kwargs dictionary passed to the AttentionProcessor."), + InputParam( + "generator", + default=None, + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="One or a list of torch generator(s) to make generation deterministic."), + InputParam( + "eta", + default=0.0, + type_hint=float, + description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."), + InputParam( + "guider_kwargs", + default=None, + type_hint=Optional[Dict[str, Any]], + description="Optional kwargs dictionary passed to the Guider."), ] @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("latents", required=True), - InputParam("batch_size", required=True), - InputParam("timesteps", required=True), - InputParam("num_inference_steps", required=True), - InputParam("prompt_embeds", required=True), - InputParam("negative_prompt_embeds"), - InputParam("add_time_ids", required=True), - InputParam("negative_add_time_ids"), - InputParam("pooled_prompt_embeds", required=True), - InputParam("negative_pooled_prompt_embeds"), - InputParam("timestep_cond"), # LCM - InputParam("mask"), # inpainting - InputParam("masked_image_latents"), # inpainting - InputParam("noise"), # inpainting - InputParam("image_latents"), # inpainting - InputParam("crops_coords"), # inpainting - InputParam("ip_adapter_embeds"), # ip adapter - InputParam("negative_ip_adapter_embeds"), # ip adapter + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" + ), + InputParam( + "add_time_ids", + required=True, + type_hint=torch.Tensor, + description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." + ), + InputParam( + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], + description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents")] + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] def __init__(self): super().__init__() @@ -1805,24 +2671,33 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state + class StableDiffusionXLDecodeLatentsStep(PipelineBlock): expected_components = ["vae"] model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - (InputParam("output_type", default="pil")), - (InputParam("return_dict", default=True)), + InputParam( + "output_type", + type_hint=str, + default="pil", + description="The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array." + ), ] @property def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True)] + return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("images")] + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] def __init__(self): super().__init__() @@ -1886,21 +2761,44 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return "A post-processing step that overlays the mask on the image (inpainting task only)" + \ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + @property def inputs(self) -> List[Tuple[str, Any]]: return [ - (InputParam("image", required=True)), - (InputParam("mask_image", required=True)), - (InputParam("padding_mask_crop", default=None)), + InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." + ), + InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="The mask image(s) to use for inpainting, white pixels in the mask will be repainted, while black pixels will be preserved. If mask_image is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be (B, H, W, 1). Must be a `PIL.Image.Image`" + ), + InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + default=None, + description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied. If set, it will find a rectangular region with the same aspect ratio as the image that contains all masked areas, then expand that area by this margin. The image and mask_image are cropped to this expanded area before resizing to the original size for inpainting. Useful when the masked area is small in a large image with irrelevant background information." + ), ] @property def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True), InputParam("crops_coords")] + return [ + InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), + InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") + ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("images")] + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1917,21 +2815,25 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLOutputStep(PipelineBlock): model_name = "stable-diffusion-xl" + @property + def description(self) -> str: + return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + @property def inputs(self) -> List[Tuple[str, Any]]: - return [(InputParam("return_dict", default=True))] + return [(InputParam("return_dict", type_hint=bool, default=True, description="Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple."))] @property def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True)] + return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("images")] + return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] @property def outputs(self) -> List[Tuple[str, Any]]: - return [(OutputParam("images"))] + return [(OutputParam("images", type_hint=Union[Tuple[PIL.Image.Image], StableDiffusionXLPipelineOutput], description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`"))] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1945,54 +2847,131 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return pipeline, state -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): +# Encode +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] block_names = ["inpaint", "img2img"] block_trigger_inputs = ["mask_image", "image"] + @property + def description(self): + return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + +# Before denoise class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] block_names = ["inpaint", "img2img", "text2img"] block_trigger_inputs = ["mask", "image_latents", None] + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." +# Denoise class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_union", "controlnet", "unet"] block_trigger_inputs = ["control_mode", "control_image", None] + @property + def description(self): + return "Denoise step that denoise the latents.\n" + \ + "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ + " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \ + " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." + +# After denoise + +class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] + block_names = ["decode", "output"] + + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] + + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ + " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + + class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] block_names = ["inpaint", "non-inpaint"] block_trigger_inputs = ["padding_mask_crop", None] + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + +# block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("input", StableDiffusionXLInputStep), From 00a3bc9d6c5cae631a746e194fdd232e3b0fb93e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 23 Jan 2025 18:16:00 +0100 Subject: [PATCH 055/170] fix --- src/diffusers/loaders/lora_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index fa259a7104d2..0852cc7f2511 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -432,7 +432,7 @@ def _func_optionally_disable_offloading(_pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None and hasattr(_pipeline, hf_device_map) and _pipeline.hf_device_map is None: + if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: From 4bed3e306e374c2ab2f1459db8c98e0498330e5e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 26 Jan 2025 13:04:33 +0100 Subject: [PATCH 056/170] up up --- src/diffusers/loaders/ip_adapter.py | 3 + src/diffusers/pipelines/components_manager.py | 105 +++++++++++++++--- 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 895dce22dc12..c690f777af7f 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -586,6 +586,9 @@ def unload_ip_adapter(self): """ # remove hidden encoder + if self.unet is None: + return + self.unet.encoder_hid_proj = None self.unet.config.encoder_hid_dim_type = None diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index d6a4b5958750..d183aabaeb4b 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -14,14 +14,17 @@ from collections import OrderedDict from itertools import combinations -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Any import torch +import time +from dataclasses import dataclass from ..utils import ( is_accelerate_available, logging, ) +from ..models.modeling_utils import ModelMixin if is_accelerate_available(): @@ -95,9 +98,6 @@ def pre_forward(self, module, *args, **kwargs): if self.other_hooks is not None: hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] # offload all other hooks - import time - - # YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later) start_time = time.perf_counter() if self.offload_strategy is not None: hooks_to_offload = self.offload_strategy( @@ -231,17 +231,27 @@ def search_best_candidate(module_sizes, min_memory_offload): class ComponentsManager: def __init__(self): self.components = OrderedDict() + self.added_time = OrderedDict() # Store when components were added self.model_hooks = None self._auto_offload_enabled = False def add(self, name, component): - if name not in self.components: - self.components[name] = component - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + if name in self.components: + logger.warning(f"Overriding existing component '{name}' in ComponentsManager") + self.components[name] = component + self.added_time[name] = time.time() + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) def remove(self, name): + if name not in self.components: + logger.warning(f"Component '{name}' not found in ComponentsManager") + return + self.components.pop(name) + self.added_time.pop(name) + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) @@ -294,6 +304,61 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False + def get_model_info(self, name: str) -> Optional[Dict[str, Any]]: + """Get comprehensive information about a model component. + + Args: + name: Name of the component to get info for + + Returns: + Dictionary containing model metadata including: + - model_id: Name of the model + - class_name: Class name of the model + - device: Device the model is on + - dtype: Data type of the model + - size_gb: Size of the model in GB + - added_time: Timestamp when model was added + - active_adapters: List of active adapters (if applicable) + - attn_proc: List of attention processor types (if applicable) + Returns None if component is not a torch.nn.Module + """ + if name not in self.components: + raise ValueError(f"Component '{name}' not found in ComponentsManager") + + component = self.components[name] + + # Only process torch.nn.Module components + if not isinstance(component, torch.nn.Module): + return None + + info = { + "model_id": name, + "class_name": component.__class__.__name__, + "device": str(getattr(component, "device", "N/A")), + "dtype": str(component.dtype) if hasattr(component, "dtype") else None, + "added_time": self.added_time[name], + "size_gb": get_memory_footprint(component) / (1024**3), + "active_adapters": None, # Default to None + } + + # Get active adapters if applicable + if isinstance(component, ModelMixin): + from peft.tuners.tuners_utils import BaseTunerLayer + for module in component.modules(): + if isinstance(module, BaseTunerLayer): + info["active_adapters"] = module.active_adapters + break + + # Get attention processors if applicable + if hasattr(component, "attn_processors"): + processors = component.attn_processors + # Get unique processor types + processor_types = list(set(str(v.__class__.__name__) for v in processors.values())) + if processor_types: + info["attn_proc"] = processor_types + + return info + def __repr__(self): col_widths = { "id": max(15, max(len(id) for id in self.components.keys())), @@ -323,14 +388,12 @@ def __repr__(self): # Model entries for name, component in models.items(): - device = component.device - dtype = component.dtype - size_bytes = get_memory_footprint(component) - size_gb = size_bytes / (1024**3) - - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | " + info = self.get_model_info(name) + output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " output += ( - f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n" + f"{info['device']:<{col_widths['device']}} | " + f"{info['dtype']:<{col_widths['dtype']}} | " + f"{info['size_gb']:.2f}\n" ) output += dash_line @@ -348,6 +411,18 @@ def __repr__(self): output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" output += dash_line + # Add additional component info + output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" + for name in self.components: + info = self.get_model_info(name) + if info is not None and (info.get("active_adapters") is not None or info.get("attn_proc")): + output += f"\n{name}:\n" + if info.get("active_adapters") is not None: + output += f" Active Adapters: {info['active_adapters']}\n" + if info.get("attn_proc"): + output += f" Attention Processors: {info['attn_proc']}\n" + output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" + return output def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs): From c7020df2cf18c97f2db96551cc996615803e333b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 27 Jan 2025 11:33:27 +0100 Subject: [PATCH 057/170] add model_info --- src/diffusers/pipelines/components_manager.py | 166 ++++--- .../pipelines/modular_pipeline_utils.py | 415 ++++++++++++++++++ 2 files changed, 531 insertions(+), 50 deletions(-) create mode 100644 src/diffusers/pipelines/modular_pipeline_utils.py diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index d183aabaeb4b..e261809ba5af 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -304,59 +304,66 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - def get_model_info(self, name: str) -> Optional[Dict[str, Any]]: - """Get comprehensive information about a model component. + def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + """Get comprehensive information about a component. Args: name: Name of the component to get info for - + fields: Optional field(s) to return. Can be a string for single field or list of fields. + If None, returns all fields. + Returns: - Dictionary containing model metadata including: - - model_id: Name of the model - - class_name: Class name of the model - - device: Device the model is on - - dtype: Data type of the model - - size_gb: Size of the model in GB - - added_time: Timestamp when model was added - - active_adapters: List of active adapters (if applicable) - - attn_proc: List of attention processor types (if applicable) - Returns None if component is not a torch.nn.Module + Dictionary containing requested component metadata. + If fields is specified, returns only those fields. + If a single field is requested as string, returns just that field's value. """ if name not in self.components: raise ValueError(f"Component '{name}' not found in ComponentsManager") component = self.components[name] - # Only process torch.nn.Module components - if not isinstance(component, torch.nn.Module): - return None - + # Build complete info dict first info = { "model_id": name, - "class_name": component.__class__.__name__, - "device": str(getattr(component, "device", "N/A")), - "dtype": str(component.dtype) if hasattr(component, "dtype") else None, "added_time": self.added_time[name], - "size_gb": get_memory_footprint(component) / (1024**3), - "active_adapters": None, # Default to None } - - # Get active adapters if applicable - if isinstance(component, ModelMixin): - from peft.tuners.tuners_utils import BaseTunerLayer - for module in component.modules(): - if isinstance(module, BaseTunerLayer): - info["active_adapters"] = module.active_adapters - break - - # Get attention processors if applicable - if hasattr(component, "attn_processors"): - processors = component.attn_processors - # Get unique processor types - processor_types = list(set(str(v.__class__.__name__) for v in processors.values())) - if processor_types: - info["attn_proc"] = processor_types - + + # Additional info for torch.nn.Module components + if isinstance(component, torch.nn.Module): + info.update({ + "class_name": component.__class__.__name__, + "size_gb": get_memory_footprint(component) / (1024**3), + "adapters": None, # Default to None + }) + + # Get adapters if applicable + if hasattr(component, "peft_config"): + info["adapters"] = list(component.peft_config.keys()) + + # Check for IP-Adapter scales + if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): + processors = component.attn_processors + # First check if any processor is an IP-Adapter + processor_types = [v.__class__.__name__ for v in processors.values()] + if any("IPAdapter" in ptype for ptype in processor_types): + # Then get scales only from IP-Adapter processors + scales = { + k: v.scale + for k, v in processors.items() + if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ + } + if scales: + info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + + # If fields specified, filter info + if fields is not None: + if isinstance(fields, str): + # Single field requested, return just that value + return {fields: info.get(fields)} + else: + # List of fields requested, return dict with just those fields + return {k: v for k, v in info.items() if k in fields} + return info def __repr__(self): @@ -383,18 +390,16 @@ def __repr__(self): output += "Models:\n" + dash_line # Column headers output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n" + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) + device = str(getattr(component, "device", "N/A")) + dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += ( - f"{info['device']:<{col_widths['device']}} | " - f"{info['dtype']:<{col_widths['dtype']}} | " - f"{info['size_gb']:.2f}\n" - ) + output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" output += dash_line # Other components section @@ -415,12 +420,12 @@ def __repr__(self): output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" for name in self.components: info = self.get_model_info(name) - if info is not None and (info.get("active_adapters") is not None or info.get("attn_proc")): + if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): output += f"\n{name}:\n" - if info.get("active_adapters") is not None: - output += f" Active Adapters: {info['active_adapters']}\n" - if info.get("attn_proc"): - output += f" Attention Processors: {info['attn_proc']}\n" + if info.get("adapters") is not None: + output += f" Adapters: {info['adapters']}\n" + if info.get("ip_adapter"): + output += f" IP-Adapter: Enabled\n" output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" return output @@ -438,3 +443,64 @@ def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs): f"1. remove the existing component with remove('{name}')\n" f"2. Use a different name: add('{name}_2', component)" ) + +def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: + """Summarizes a dictionary by finding common prefixes that share the same value. + + For a dictionary with dot-separated keys like: + { + 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], + 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + } + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: + { + 'down_blocks': [0.6], + 'up_blocks': [0.3] + } + """ + # First group by values - convert lists to tuples to make them hashable + value_to_keys = {} + for key, value in d.items(): + value_tuple = tuple(value) if isinstance(value, list) else value + if value_tuple not in value_to_keys: + value_to_keys[value_tuple] = [] + value_to_keys[value_tuple].append(key) + + def find_common_prefix(keys: List[str]) -> str: + """Find the shortest common prefix among a list of dot-separated keys.""" + if not keys: + return "" + if len(keys) == 1: + return keys[0] + + # Split all keys into parts + key_parts = [k.split('.') for k in keys] + + # Find how many initial parts are common + common_length = 0 + for parts in zip(*key_parts): + if len(set(parts)) == 1: # All parts at this position are the same + common_length += 1 + else: + break + + if common_length == 0: + return "" + + # Return the common prefix + return '.'.join(key_parts[0][:common_length]) + + # Create summary by finding common prefixes for each value group + summary = {} + for value_tuple, keys in value_to_keys.items(): + prefix = find_common_prefix(keys) + if prefix: # Only add if we found a common prefix + # Convert tuple back to list if it was originally a list + value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple + summary[prefix] = value + else: + summary[""] = value # Use empty string if no common prefix + + return summary diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..7a2737f63fe1 --- /dev/null +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -0,0 +1,415 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import traceback +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union + + +import torch +from tqdm.auto import tqdm +import re + +from ..configuration_utils import ConfigMixin +from ..utils import ( + is_accelerate_available, + is_accelerate_version, + logging, +) +from .pipeline_loading_utils import _get_pipeline_class + + +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + outputs: Dict[str, Any] = field(default_factory=dict) + + def add_input(self, key: str, value: Any): + self.inputs[key] = value + + def add_intermediate(self, key: str, value: Any): + self.intermediates[key] = value + + def add_output(self, key: str, value: Any): + self.outputs[key] = value + + def get_input(self, key: str, default: Any = None) -> Any: + return self.inputs.get(key, default) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_output(self, key: str, default: Any = None) -> Any: + if key in self.outputs: + return self.outputs[key] + elif key in self.intermediates: + return self.intermediates[key] + else: + return default + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" outputs={{\n{outputs}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __repr__(self): + def format_value(v): + # Handle tensors directly + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + + # Handle lists of tensors + elif isinstance(v, list): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"List[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle tuples of tensors + elif isinstance(v, tuple): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle dicts with tensor values + elif isinstance(v, dict): + if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): + shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} + return f"Dict of Tensors with shapes {shapes}" + return repr(v) + + # Default case + return repr(v) + + attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) + return f"BlockState(\n{attributes}\n)" + + +@dataclass +class InputParam: + name: str + default: Any = None + required: bool = False + description: str = "" + type_hint: Any = Any + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + +@dataclass +class OutputParam: + name: str + description: str = "" + type_hint: Any = Any + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + block: Pipeline block with potential intermediates + + Returns: + str: Formatted string like "input1, Required(input2) -> output1, output2" + """ + # Handle inputs + input_parts = [] + + for inp in intermediates_inputs: + parts = [] + # Check if input is required + if inp.name in required_intermediates_inputs: + parts.append("Required") + + # Get base name or modified name + name = inp.name + if name in {out.name for out in intermediates_outputs}: + name = f"*{name}" + + # Combine Required() wrapper with possibly starred name + if parts: + input_parts.append(f"Required({name})") + else: + input_parts.append(name) + + # Handle outputs + output_parts = [] + outputs = [out.name for out in intermediates_outputs] + # Only show new outputs if we have inputs + inputs_set = {inp.name for inp in intermediates_inputs} + outputs = [out for out in outputs if out not in inputs_set] + output_parts.extend(outputs) + + # Combine with arrow notation if both inputs and outputs exist + if output_parts: + return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}" + elif input_parts: + return ', '.join(input_parts) + return "" + + +def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text: str, indent: str, max_length: int) -> str: + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if isinstance(param, InputParam): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +# Then update the original functions to use this combined version: +def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: + return format_params(input_params, "Args", indent_level, max_line_length) + + +def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: + return format_params(output_params, "Returns", indent_level, max_line_length) + + +def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Returns: + str: A formatted string containing information about call parameters, intermediate inputs/outputs, + and final intermediate outputs. + """ + output = "" + + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # YiYi TODO: refactor to remove this and `outputs` attribute instead + if final_intermediates_outputs: + output += "\n\n" + output += format_output_params(final_intermediates_outputs, indent_level=2) + + if intermediates_outputs: + output += "\n\n------------------------\n" + intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2) + output += intermediates_str + + elif intermediates_outputs: + output +="\n\n" + output += format_output_params(intermediates_outputs, indent_level=2) + + + return output + + +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name in combined_dict: + current_param = combined_dict[input_param.name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): + warnings.warn( + f"Multiple different default values found for input '{input_param.name}': " + f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + else: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + + return list(combined_dict.values()) + + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if output_param.name not in combined_dict: + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) + + From 2c3e4eafa8b6b630cc1af2b320798c2a6f0931c7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 29 Jan 2025 17:58:40 +0100 Subject: [PATCH 058/170] fix --- src/diffusers/pipelines/components_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index e261809ba5af..ebbe75c286d5 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -15,6 +15,7 @@ from collections import OrderedDict from itertools import combinations from typing import List, Optional, Union, Dict, Any +import copy import torch import time @@ -342,7 +343,7 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No # Check for IP-Adapter scales if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): - processors = component.attn_processors + processors = copy.deepcopy(component.attn_processors) # First check if any processor is an IP-Adapter processor_types = [v.__class__.__name__ for v in processors.values()] if any("IPAdapter" in ptype for ptype in processor_types): From e5089d702b47fee52f3500b73390bd0e18f63a3f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 31 Jan 2025 21:55:45 +0100 Subject: [PATCH 059/170] update --- src/diffusers/pipelines/components_manager.py | 2 +- .../pipeline_stable_diffusion_xl_modular.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index ebbe75c286d5..b6634075ec33 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -266,7 +266,7 @@ def get(self, names: Union[str, List[str]]): else: raise ValueError(f"Invalid type for names: {type(names)}") - def enable_auto_cpu_offload(self, device, memory_reserve_margin="3GB"): + def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index f42c2359a426..d013892b3a09 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3020,6 +3020,11 @@ def description(self): ("decode", StableDiffusionXLAutoDecodeStep) ]) +AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + SDXL_SUPPORTED_BLOCKS = { "text2img": TEXT2IMAGE_BLOCKS, From 8ddb20bfb8bb2dc3af3aa6eb9598b40ca11cfb61 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 1 Feb 2025 05:45:00 +0100 Subject: [PATCH 060/170] up --- src/diffusers/pipelines/modular_pipeline.py | 164 +++++++++--------- .../pipeline_stable_diffusion_xl_modular.py | 44 +++-- 2 files changed, 120 insertions(+), 88 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index a4c6baad47f5..4baf906efb57 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -61,9 +61,6 @@ def add_input(self, key: str, value: Any): def add_intermediate(self, key: str, value: Any): self.intermediates[key] = value - def add_output(self, key: str, value: Any): - self.outputs[key] = value - def get_input(self, key: str, default: Any = None) -> Any: return self.inputs.get(key, default) @@ -194,45 +191,45 @@ def format_intermediates_short(intermediates_inputs: List[InputParam], required_ Formats intermediate inputs and outputs of a block into a string representation. Args: - block: Pipeline block with potential intermediates + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters Returns: - str: Formatted string like "input1, Required(input2) -> output1, output2" + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only """ # Handle inputs input_parts = [] - for inp in intermediates_inputs: - parts = [] - # Check if input is required if inp.name in required_intermediates_inputs: - parts.append("Required") - - # Get base name or modified name - name = inp.name - if name in {out.name for out in intermediates_outputs}: - name = f"*{name}" - - # Combine Required() wrapper with possibly starred name - if parts: - input_parts.append(f"Required({name})") + input_parts.append(f"Required({inp.name})") else: - input_parts.append(name) + input_parts.append(inp.name) - # Handle outputs - output_parts = [] - outputs = [out.name for out in intermediates_outputs] - # Only show new outputs if we have inputs + # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} - outputs = [out for out in outputs if out not in inputs_set] - output_parts.extend(outputs) + modified_parts = [] + new_output_parts = [] - # Combine with arrow notation if both inputs and outputs exist - if output_parts: - return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}" - elif input_parts: - return ', '.join(input_parts) - return "" + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: @@ -323,7 +320,7 @@ def format_output_params(output_params: List[OutputParam], indent_level: int = 4 -def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""): +def make_doc_string(inputs, intermediates_inputs, outputs, description=""): """ Generates a formatted documentation string describing the pipeline block's parameters and structure. @@ -340,20 +337,8 @@ def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_i output += format_input_params(inputs + intermediates_inputs, indent_level=2) - # YiYi TODO: refactor to remove this and `outputs` attribute instead - if final_intermediates_outputs: - output += "\n\n" - output += format_output_params(final_intermediates_outputs, indent_level=2) - - if intermediates_outputs: - output += "\n\n------------------------\n" - intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2) - output += intermediates_str - - elif intermediates_outputs: - output +="\n\n" - output += format_output_params(intermediates_outputs, indent_level=2) - + output += "\n\n" + output += format_output_params(outputs, indent_level=2) return output @@ -367,23 +352,28 @@ class PipelineBlock: @property def description(self) -> str: - return "" + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") @property def inputs(self) -> List[InputParam]: - return [] + """List of input parameters. Must be implemented by subclasses.""" + raise NotImplementedError("inputs method must be implemented in subclasses") @property def intermediates_inputs(self) -> List[InputParam]: - return [] + """List of intermediate input parameters. Must be implemented by subclasses.""" + raise NotImplementedError("intermediates_inputs method must be implemented in subclasses") @property def intermediates_outputs(self) -> List[OutputParam]: - return [] - + """List of intermediate output parameters. Must be implemented by subclasses.""" + raise NotImplementedError("intermediates_outputs method must be implemented in subclasses") + + # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @property def outputs(self) -> List[OutputParam]: - return [] + return self.intermediates_outputs @property def required_inputs(self) -> List[str]: @@ -413,7 +403,7 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ - # Components section - group into main components and auxiliaries if needed + # Components section expected_components = set(getattr(self, "expected_components", [])) loaded_components = set(self.components.keys()) all_components = sorted(expected_components | loaded_components) @@ -446,7 +436,7 @@ def __repr__(self): # Intermediates section intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates(`*` = modified):\n {intermediates_str}" + intermediates = f"Intermediates:\n{intermediates_str}" return ( f"{class_name}(\n" @@ -461,7 +451,7 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description) + return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) def get_block_state(self, state: PipelineState) -> dict: @@ -489,11 +479,6 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) - - for output_param in self.outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Output '{output_param.name}' is missing in block state") - state.add_output(output_param.name, getattr(block_state, output_param.name)) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -704,6 +689,12 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -803,10 +794,21 @@ def __repr__(self): sections.append(f" Block: {block.__class__.__name__}") inputs_str = format_inputs_short(block.inputs) - sections.append(f" inputs:\n {inputs_str}") + sections.append(f" inputs: {inputs_str}") - intermediates_str = f" intermediates(`*` = modified):\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}" - sections.append(intermediates_str) + # Format intermediates with proper indentation + intermediates_str = format_intermediates_short( + block.intermediates_inputs, + block.required_intermediates_inputs, + block.intermediates_outputs + ) + if intermediates_str != " (none)": # Only add if there are intermediates + sections.append(" intermediates:") + # Add extra indentation to each line of intermediates + indented_intermediates = "\n".join( + " " + line for line in intermediates_str.split("\n") + ) + sections.append(indented_intermediates) sections.append("") @@ -819,7 +821,7 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description) + return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) class SequentialPipelineBlocks: """ @@ -962,7 +964,7 @@ def intermediates_outputs(self) -> List[str]: return combined_outputs @property - def final_intermediates_outputs(self) -> List[str]: + def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs @torch.no_grad() @@ -1121,28 +1123,34 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" + # Format inputs inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - intermediates_str = format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs) - - if intermediates_str: - blocks_str += f" intermediates(`*` = modified): {intermediates_str}\n" + # Format intermediates with proper indentation + intermediates_str = format_intermediates_short( + block.intermediates_inputs, + block.required_intermediates_inputs, + block.intermediates_outputs + ) + if intermediates_str != " (none)": # Only add if there are intermediates + blocks_str += " intermediates:\n" + # Add extra indentation to each line of intermediates + indented_intermediates = "\n".join( + " " + line for line in intermediates_str.split("\n") + ) + blocks_str += f"{indented_intermediates}\n" blocks_str += "\n" inputs_str = format_inputs_short(self.inputs) inputs_str = " Inputs:\n " + inputs_str - final_intermediates_outputs = [out.name for out in self.final_intermediates_outputs] + outputs = [out.name for out in self.outputs] - intermediates_str_short = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_input_str = intermediates_str_short.split('->')[0].strip() # "Required(latents), crops_coords" - intermediates_output_str = intermediates_str_short.split('->')[1].strip() + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) intermediates_str = ( "\n Intermediates:\n" - f" - inputs: {intermediates_input_str}\n" - f" - outputs: {intermediates_output_str}\n" - f" - final outputs: {', '.join(final_intermediates_outputs)}" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" ) return ( @@ -1158,7 +1166,7 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, self.final_intermediates_outputs, self.description) + return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) class ModularPipeline(ConfigMixin): """ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index d013892b3a09..02fcf380459e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -139,6 +139,19 @@ def description(self) -> str: " for more details" ) + + @property + def inputs(self) -> List[InputParam]: + return [] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [] + def __init__(self): super().__init__() self.components["text_encoder"] = None @@ -178,11 +191,17 @@ def inputs(self) -> List[InputParam]: description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`." ), ] - + @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")] + def intermediates_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] def __init__(self): super().__init__() @@ -270,7 +289,11 @@ def inputs(self) -> List[InputParam]: @property - def intermediates_outputs(self) -> List[str]: + def intermediates_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), @@ -378,13 +401,13 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property - def intermediates_outputs(self) -> List[str]: + def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] def __init__(self): @@ -818,6 +841,10 @@ def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] + @property + def intermediates_inputs(self) -> List[InputParam]: + return [] + def __init__(self): super().__init__() self.components["scheduler"] = None @@ -2831,9 +2858,6 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - @property - def outputs(self) -> List[Tuple[str, Any]]: - return [(OutputParam("images", type_hint=Union[Tuple[PIL.Image.Image], StableDiffusionXLPipelineOutput], description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`"))] @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: From cff0fd62608a3cd060b942702ade49fdd599384b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 1 Feb 2025 11:36:13 +0100 Subject: [PATCH 061/170] more refactor --- src/diffusers/pipelines/modular_pipeline.py | 83 ++++++++++++++----- .../pipeline_stable_diffusion_xl_modular.py | 20 +++-- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 4baf906efb57..4400b500b666 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -53,7 +53,6 @@ class PipelineState: inputs: Dict[str, Any] = field(default_factory=dict) intermediates: Dict[str, Any] = field(default_factory=dict) - outputs: Dict[str, Any] = field(default_factory=dict) def add_input(self, key: str, value: Any): self.inputs[key] = value @@ -64,19 +63,17 @@ def add_intermediate(self, key: str, value: Any): def get_input(self, key: str, default: Any = None) -> Any: return self.inputs.get(key, default) + def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.inputs.get(key, default) for key in keys} + def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) - def get_output(self, key: str, default: Any = None) -> Any: - if key in self.outputs: - return self.outputs[key] - elif key in self.intermediates: - return self.intermediates[key] - else: - return default + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.intermediates.get(key, default) for key in keys} def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} def __repr__(self): def format_value(v): @@ -89,13 +86,11 @@ def format_value(v): inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) return ( f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" outputs={{\n{outputs}\n }}\n" + f" intermediates={{\n{intermediates}\n }}\n" f")" ) @@ -403,6 +398,16 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + # Components section expected_components = set(getattr(self, "expected_components", [])) loaded_components = set(self.components.keys()) @@ -441,6 +446,7 @@ def __repr__(self): return ( f"{class_name}(\n" f" Class: {base_class}\n" + f"{desc}" f" {components}\n" f" {configs}\n" f" {inputs}\n" @@ -760,16 +766,33 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ - all_triggers = set(self.trigger_to_block_map.keys()) - + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + sections = [] + all_triggers = set(self.trigger_to_block_map.keys()) for trigger in sorted(all_triggers, key=lambda x: str(x)): sections.append(f"\n Trigger Input: {trigger}\n") block = self.trigger_to_block_map.get(trigger) if block is None: continue - + + # Add block description with proper indentation + desc_lines = block.description.split('\n') + # First line starts right after "Description:", subsequent lines get indented + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line + sections.append(f" Description: {indented_desc}\n") + expected_components = set(getattr(block, "expected_components", [])) loaded_components = set(k for k, v in self.components.items() if v is not None and hasattr(block, k)) @@ -815,6 +838,7 @@ def __repr__(self): return ( f"{class_name}(\n" f" Class: {base_class}\n" + f"{desc}" f"{chr(10).join(sections)}" f")" ) @@ -1097,6 +1121,16 @@ def __repr__(self): header += " " + "=" * 100 + "\n" # Add decorative line header += "\n" # Add empty line after + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + # Components section expected_components = set(getattr(self, "expected_components", [])) loaded_components = set(self.components.keys()) @@ -1122,6 +1156,13 @@ def __repr__(self): blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" + + desc_lines = block.description.split('\n') + # First line starts right after "Description:", subsequent lines get indented + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line + blocks_str += f" Description: {indented_desc}\n" # Format inputs inputs_str = format_inputs_short(block.inputs) @@ -1155,6 +1196,7 @@ def __repr__(self): return ( f"{header}\n" + f"{desc}" f"{components_str}\n" f"{auxiliaries_str}\n" f"{configs_str}\n" @@ -1329,13 +1371,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = if output is None: return state - if isinstance(output, str): - return state.get_output(output) + + elif isinstance(output, str): + return state.get_intermediate(output) + elif isinstance(output, (list, tuple)): - outputs = {} - for output_name in output: - outputs[output_name] = state.get_output(output_name) - return outputs + return state.get_intermediates(output) else: raise ValueError(f"Output '{output}' is not a valid output type") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 02fcf380459e..093e6cfac101 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -693,7 +693,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @property def description(self) -> str: return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation." + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." ) @@ -2790,7 +2790,7 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): @property def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only)" + \ + return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" @property @@ -2962,10 +2962,10 @@ class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): @property def description(self): - return "Decode step that decode the denoised latents into images outputs.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + return """Decode step that decode the denoised latents into images outputs. +This is a sequential pipeline blocks: + - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images + - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): @@ -2995,6 +2995,14 @@ def description(self): " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." +class StableDiffusionAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decode"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), From 485f8d175830423d0639878b55292bb5aba4f2a6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 1 Feb 2025 21:30:05 +0100 Subject: [PATCH 062/170] more refactor --- src/diffusers/pipelines/components_manager.py | 122 ++++++++++++++++-- src/diffusers/pipelines/modular_pipeline.py | 36 ++++-- .../pipeline_stable_diffusion_xl_modular.py | 22 +++- 3 files changed, 157 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index b6634075ec33..6d7665e29292 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -256,13 +256,99 @@ def remove(self, name): if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) + # YiYi TODO: looking into improving the search pattern def get(self, names: Union[str, List[str]]): + """ + Get components by name with simple pattern matching. + + Args: + names: Component name(s) or pattern(s) + Patterns: + - "unet" : exact match + - "!unet" : everything except exact match "unet" + - "base_*" : everything starting with "base_" + - "!base_*" : everything NOT starting with "base_" + - "*unet*" : anything containing "unet" + - "!*unet*" : anything NOT containing "unet" + - "refiner|vae|unet" : anything containing any of these terms + - "!refiner|vae|unet" : anything NOT containing any of these terms + + Returns: + Single component if names is str and matches one component, + dict of components if names matches multiple components or is a list + """ if isinstance(names, str): - if names not in self.components: + # Check if this is a "not" pattern + is_not_pattern = names.startswith('!') + if is_not_pattern: + names = names[1:] # Remove the ! prefix + + # Handle OR patterns (containing |) + if '|' in names: + terms = names.split('|') + matches = { + name: comp for name, comp in self.components.items() + if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") + else: + logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") + + # Exact match + elif names in self.components: + if is_not_pattern: + matches = { + name: comp for name, comp in self.components.items() + if name != names + } + logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting component: {names}") + return self.components[names] + + # Prefix match (ends with *) + elif names.endswith('*'): + prefix = names[:-1] + matches = { + name: comp for name, comp in self.components.items() + if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") + else: + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + + # Contains match (starts with *) + elif names.startswith('*'): + search = names[1:-1] if names.endswith('*') else names[1:] + matches = { + name: comp for name, comp in self.components.items() + if (search in name) != is_not_pattern # Flip condition if not pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + else: raise ValueError(f"Component '{names}' not found in ComponentsManager") - return self.components[names] + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + return matches if len(matches) > 1 else next(iter(matches.values())) + elif isinstance(names, list): - return {n: self.components[n] for n in names} + results = {} + for name in names: + result = self.get(name) + if isinstance(result, dict): + results.update(result) + else: + results[name] = result + logger.info(f"Getting multiple components: {list(results.keys())}") + return results + else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -431,18 +517,34 @@ def __repr__(self): return output - def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs): + def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + """ + Load components from a pretrained model and add them to the manager. + + Args: + pretrained_model_name_or_path (str): The path or identifier of the pretrained model + prefix (str, optional): Prefix to add to all component names loaded from this model. + If provided, components will be named as "{prefix}_{component_name}" + **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() + """ from ..pipelines.pipeline_utils import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): - if name not in self.components and component is not None: - self.add(name, component) - elif name in self.components: + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: logger.warning( - f"Component '{name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{name}')\n" - f"2. Use a different name: add('{name}_2', component)" + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 4400b500b666..6948930a6bbd 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -975,10 +975,16 @@ def intermediates_inputs(self) -> List[str]: for block in self.blocks.values(): # Add inputs that aren't in outputs yet inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) return inputs @property @@ -1035,47 +1041,59 @@ def trigger_inputs(self): return self._get_trigger_inputs() def _traverse_trigger_blocks(self, trigger_inputs): + # Convert trigger_inputs to a set for easier manipulation + active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, trigger_inputs): + def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() + # sequential or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): # sequential for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs) + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) result_blocks.update(blocks_to_update) else: # PipelineBlock result_blocks[block_name] = block + # Add this block's output names to active triggers if defined + if hasattr(block, 'outputs'): + active_triggers.update(out.name for out in block.outputs) return result_blocks # auto else: - # Find first block_trigger_input that matches any value in our trigger_value tuple + # Find first block_trigger_input that matches any value in our active_triggers this_block = None + matching_trigger = None for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in trigger_inputs: + if trigger_input is not None and trigger_input in active_triggers: this_block = block.trigger_to_block_map[trigger_input] + matching_trigger = trigger_input break # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] + matching_trigger = None if this_block is not None: # sequential/auto if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, trigger_inputs)) + result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block + # Add this block's output names to active triggers if defined + if hasattr(this_block, 'outputs'): + active_triggers.update(out.name for out in this_block.outputs) return result_blocks all_blocks = OrderedDict() for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs) + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) all_blocks.update(blocks_to_update) return all_blocks diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 093e6cfac101..e2d20f8a7ed0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2994,14 +2994,28 @@ def description(self): " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] -class StableDiffusionAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decode"] + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] @property def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ From addaad013cf84bff7a643252924eb838b729d40c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 3 Feb 2025 20:36:05 +0100 Subject: [PATCH 063/170] more more more refactor --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/modular_pipeline.py | 207 +++++++++++------- .../pipelines/stable_diffusion_xl/__init__.py | 2 + 4 files changed, 138 insertions(+), 75 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6233dcf3a5fd..5afdbc18d8e3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -406,6 +406,7 @@ "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", + "StableDiffusionXLAutoPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -897,6 +898,7 @@ StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, + StableDiffusionXLAutoPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9c63637e2b5c..c20f92bb9d60 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -317,6 +317,7 @@ "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", "StableDiffusionXLModularPipeline", + "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -667,6 +668,7 @@ StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLModularPipeline, StableDiffusionXLPipeline, + StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 6948930a6bbd..b50d00dbc219 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -765,6 +765,22 @@ def trigger_inputs(self): def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" # Format description with proper indentation desc_lines = self.description.split('\n') @@ -776,70 +792,92 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - sections = [] - all_triggers = set(self.trigger_to_block_map.keys()) - for trigger in sorted(all_triggers, key=lambda x: str(x)): - sections.append(f"\n Trigger Input: {trigger}\n") - - block = self.trigger_to_block_map.get(trigger) - if block is None: - continue + # Components section + expected_components = set(getattr(self, "expected_components", [])) + loaded_components = set(self.components.keys()) + all_components = sorted(expected_components | loaded_components) + components_str = " Components:\n" + "\n".join( + f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" + for k in all_components + ) - # Add block description with proper indentation + # Auxiliaries section + auxiliaries_str = " Auxiliaries:\n" + "\n".join( + f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() + ) + + # Configs section + expected_configs = set(getattr(self, "expected_configs", [])) + loaded_configs = set(self.configs.keys()) + all_configs = sorted(expected_configs | loaded_configs) + configs_str = " Configs:\n" + "\n".join( + f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items() + ) + + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description desc_lines = block.description.split('\n') - # First line starts right after "Description:", subsequent lines get indented indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line - sections.append(f" Description: {indented_desc}\n") - - expected_components = set(getattr(block, "expected_components", [])) - loaded_components = set(k for k, v in self.components.items() - if v is not None and hasattr(block, k)) - all_components = sorted(expected_components | loaded_components) - if all_components: - sections.append(" Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components - else f" - {k}" for k in all_components - )) - - if self.auxiliaries: - sections.append(" Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" - for k, v in self.auxiliaries.items() - )) - - if self.configs: - sections.append(" Configs:\n" + "\n".join( - f" - {k}={v}" for k, v in self.configs.items() - )) - - sections.append(f" Block: {block.__class__.__name__}") - + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n" + + # Format inputs inputs_str = format_inputs_short(block.inputs) - sections.append(f" inputs: {inputs_str}") + blocks_str += f" inputs: {inputs_str}\n" - # Format intermediates with proper indentation + # Format intermediates intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, + block.intermediates_inputs, + block.required_intermediates_inputs, block.intermediates_outputs ) - if intermediates_str != " (none)": # Only add if there are intermediates - sections.append(" intermediates:") - # Add extra indentation to each line of intermediates + if intermediates_str != " (none)": + blocks_str += " intermediates:\n" indented_intermediates = "\n".join( " " + line for line in intermediates_str.split("\n") ) - sections.append(indented_intermediates) - - sections.append("") + blocks_str += f"{indented_intermediates}\n" + blocks_str += "\n" + + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + outputs = [out.name for out in self.outputs] + + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = ( + "\n Intermediates:\n" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" + ) return ( - f"{class_name}(\n" - f" Class: {base_class}\n" + f"{header}\n" f"{desc}" - f"{chr(10).join(sections)}" + f"{components_str}\n" + f"{auxiliaries_str}\n" + f"{configs_str}\n" + f"{blocks_str}\n" + f"{inputs_str}\n" + f"{intermediates_str}\n" f")" ) @@ -1097,7 +1135,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): all_blocks.update(blocks_to_update) return all_blocks - def get_triggered_blocks(self, *trigger_inputs): + def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs if trigger_inputs is not None: @@ -1130,14 +1168,14 @@ def __repr__(self): if self.trigger_inputs: - header += "\n" # Add empty line before - header += " " + "=" * 100 + "\n" # Add decorative line - header += " This pipeline block contains dynamic blocks that are selected at runtime based on your inputs.\n" - header += " You can use `get_triggered_blocks(input1, input2,...)` to see which blocks will be used for your trigger inputs.\n" - header += " Use `get_triggered_blocks()` to see blocks will be used for default inputs (when no trigger inputs are provided)\n" + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" header += f" Trigger Inputs: {self.trigger_inputs}\n" - header += " " + "=" * 100 + "\n" # Add decorative line - header += "\n" # Add empty line after + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" # Format description with proper indentation desc_lines = self.description.split('\n') @@ -1173,28 +1211,42 @@ def __repr__(self): blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): - blocks_str += f" {i}. {name} ({block.__class__.__name__})\n" + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + # Add block description desc_lines = block.description.split('\n') - # First line starts right after "Description:", subsequent lines get indented indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n" # Format inputs inputs_str = format_inputs_short(block.inputs) blocks_str += f" inputs: {inputs_str}\n" - # Format intermediates with proper indentation + # Format intermediates intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, + block.intermediates_inputs, + block.required_intermediates_inputs, block.intermediates_outputs ) - if intermediates_str != " (none)": # Only add if there are intermediates + if intermediates_str != " (none)": blocks_str += " intermediates:\n" - # Add extra indentation to each line of intermediates indented_intermediates = "\n".join( " " + line for line in intermediates_str.split("\n") ) @@ -1295,6 +1347,10 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + + def get_execution_blocks(self, *trigger_inputs): + return self.pipeline_block.get_execution_blocks(*trigger_inputs) + @property def dtype(self) -> torch.dtype: r""" @@ -1449,16 +1505,7 @@ def __repr__(self): block = self.pipeline_block - if hasattr(block, "trigger_inputs") and block.trigger_inputs: - output += "\n" - output += " Trigger Inputs:\n" - output += " --------------\n" - output += f" This pipeline contains dynamic blocks that are selected at runtime based on your inputs.\n" - output += f" • Trigger inputs: {block.trigger_inputs}\n" - output += f" • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs\n" - output += f" • Use .pipeline_block.get_triggered_blocks() to see blocks will be used for default inputs (when no trigger inputs are provided)\n" - output += "\n" - + # List the pipeline block structure first output += "Pipeline Block:\n" output += "--------------\n" if hasattr(block, "blocks"): @@ -1493,6 +1540,16 @@ def __repr__(self): output += f"{name}: {config!r}\n" output += "\n" + # Add auto blocks section + if hasattr(block, "trigger_inputs") and block.trigger_inputs: + output += "------------------\n" + output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" + output += f"Trigger Inputs: {block.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in block.trigger_inputs if t is not None) + output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + output += "Check `.doc` of returned object for more information.\n\n" + # List the call parameters full_doc = self.pipeline_block.doc if "------------------------" in full_doc: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index a1b821d1726f..584b260eaaa8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -39,6 +39,7 @@ "StableDiffusionXLPrepareLatentsStep", "StableDiffusionXLSetTimestepsStep", "StableDiffusionXLTextEncoderStep", + "StableDiffusionXLAutoPipeline", ] if is_transformers_available() and is_flax_available(): @@ -69,6 +70,7 @@ StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoPipeline, ) try: From 12650e13934391ca8408d82a8df77bcb92032589 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 4 Feb 2025 02:08:28 +0100 Subject: [PATCH 064/170] up --- .../pipelines/modular_pipeline_utils.py | 415 ------------------ .../pipeline_stable_diffusion_xl_modular.py | 9 + 2 files changed, 9 insertions(+), 415 deletions(-) delete mode 100644 src/diffusers/pipelines/modular_pipeline_utils.py diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py deleted file mode 100644 index 7a2737f63fe1..000000000000 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union - - -import torch -from tqdm.auto import tqdm -import re - -from ..configuration_utils import ConfigMixin -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, -) -from .pipeline_loading_utils import _get_pipeline_class - - -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - outputs: Dict[str, Any] = field(default_factory=dict) - - def add_input(self, key: str, value: Any): - self.inputs[key] = value - - def add_intermediate(self, key: str, value: Any): - self.intermediates[key] = value - - def add_output(self, key: str, value: Any): - self.outputs[key] = value - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_output(self, key: str, default: Any = None) -> Any: - if key in self.outputs: - return self.outputs[key] - elif key in self.intermediates: - return self.intermediates[key] - else: - return default - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - outputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.outputs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" outputs={{\n{outputs}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - -@dataclass -class InputParam: - name: str - default: Any = None - required: bool = False - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - -@dataclass -class OutputParam: - name: str - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - - - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - block: Pipeline block with potential intermediates - - Returns: - str: Formatted string like "input1, Required(input2) -> output1, output2" - """ - # Handle inputs - input_parts = [] - - for inp in intermediates_inputs: - parts = [] - # Check if input is required - if inp.name in required_intermediates_inputs: - parts.append("Required") - - # Get base name or modified name - name = inp.name - if name in {out.name for out in intermediates_outputs}: - name = f"*{name}" - - # Combine Required() wrapper with possibly starred name - if parts: - input_parts.append(f"Required({name})") - else: - input_parts.append(name) - - # Handle outputs - output_parts = [] - outputs = [out.name for out in intermediates_outputs] - # Only show new outputs if we have inputs - inputs_set = {inp.name for inp in intermediates_inputs} - outputs = [out for out in outputs if out not in inputs_set] - output_parts.extend(outputs) - - # Combine with arrow notation if both inputs and outputs exist - if output_parts: - return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}" - elif input_parts: - return ', '.join(input_parts) - return "" - - -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - - -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) - - -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) - - -def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "" - - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - # YiYi TODO: refactor to remove this and `outputs` attribute instead - if final_intermediates_outputs: - output += "\n\n" - output += format_output_params(final_intermediates_outputs, indent_level=2) - - if intermediates_outputs: - output += "\n\n------------------------\n" - intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2) - output += intermediates_str - - elif intermediates_outputs: - output +="\n\n" - output += format_output_params(intermediates_outputs, indent_level=2) - - - return output - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. - - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {} # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if output_param.name not in combined_dict: - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index e2d20f8a7ed0..f743f442cc40 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3020,6 +3020,7 @@ def description(self): # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLSetTimestepsStep), ("prepare_latents", StableDiffusionXLPrepareLatentsStep), @@ -3030,6 +3031,7 @@ def description(self): IMAGE2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLVaeEncoderStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), @@ -3041,6 +3043,7 @@ def description(self): INPAINT_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), ("input", StableDiffusionXLInputStep), ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), @@ -3058,8 +3061,13 @@ def description(self): ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), ]) +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + AUTO_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), ("denoise", StableDiffusionXLAutoDenoiseStep), @@ -3078,6 +3086,7 @@ def description(self): "inpaint": INPAINT_BLOCKS, "controlnet": CONTROLNET_BLOCKS, "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, "auto": AUTO_BLOCKS } From a8e853b7919e46296d589bddf29de7bc17595fc8 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 19 Jun 2025 15:34:17 -1000 Subject: [PATCH 065/170] [modular diffusers] more refactor (#11235) * add componentspec and configspec * up * up * move methods to blocks * Modular Diffusers Guiders (#11311) * cfg; slg; pag; sdxl without controlnet * support sdxl controlnet * support controlnet union * update * update * cfg zero* * use unwrap_module for torch compiled modules * remove guider kwargs * remove commented code * remove old guider * fix slg bug * remove debug print * autoguidance * smoothed energy guidance * add note about seg * tangential cfg * cfg plus plus * support cfgpp in ddim * apply review suggestions * refactor * rename enable/disable * remove cfg++ for now * rename do_classifier_free_guidance->prepare_unconditional_embeds * remove unused * [modular diffusers] introducing ModularLoader (#11462) * cfg; slg; pag; sdxl without controlnet --------- Co-authored-by: Aryan * make loader optional * remove lora step and ip-adapter step -> no longer needed * rename pipeline -> components, data -> block_state * seperate controlnet step into input + denoise * refactor controlnet union * reefactor pipeline/block states so that it can dynamically accept kwargs * remove controlnet union denoise step, refactor & reuse controlnet denoisee step to accept aditional contrlnet kwargs * allow input_fields as input & update message * update input formating, consider kwarggs_type inputs with no name, e/g *_controlnet_kwargs * refactor the denoiseestep using LoopSequential! also add a new file for denoise step * change warning to debug * fix get_execusion blocks with loopsequential * fix auto denoise so all tests pass * update imports on guiders * remove modular reelated change from pipelines folder * made a modular_pipelines folder! * update __init__ * add notes * add block state will also make sure modifed intermediates_inputs will be updated * move block mappings to its own file * make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables * decode block, if skip decoding do not need to update latent * fix imports * fix import * fix more * remove the output step * make generator intermediates (it is mutable) * after_denoise -> decoders * add a to-do for guider cconfig mixin * refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method * refactor modular loader: 1. load only load (pretrained components only if not specific names) 2. update acceept create spec 3. move the updte _componeent_spec logic outside register_components to each method that create/update the component: __init__/update/load * update components manager * up * [WIP] Modular Diffusers support custom code/pipeline blocks (#11539) * update * update * remove the duplicated components_manager file I forgot to deletee * fix import in block mapping * add a to-do for modular loader * prepare_latents_img2img pipeline method -> function, maybe do the same for others? * update input for loop blocks, do not need to include intermediate * solve merge conflict: manually add back the remote code change to modular_pipeline * add node_utils * modular node! * add * refator based on dhruv's feedbacks * update doc format for kwargs_type * up * updatee modular_pipeline.from_pretrained, modular_repo ->pretrained_model_name_or_path * save_pretrained for serializing config. (#11603) * save_pretrained for serializing config. * remove pushtohub * diffusers-cli rough --------- Co-authored-by: YiYi Xu --------- Co-authored-by: Aryan Co-authored-by: Dhruv Nair Co-authored-by: Sayak Paul --- src/diffusers/__init__.py | 79 +- src/diffusers/commands/custom_blocks.py | 133 + src/diffusers/commands/diffusers_cli.py | 2 + src/diffusers/guider.py | 745 ---- src/diffusers/guiders/__init__.py | 29 + .../guiders/adaptive_projected_guidance.py | 184 + src/diffusers/guiders/auto_guidance.py | 177 + .../guiders/classifier_free_guidance.py | 132 + .../classifier_free_zero_star_guidance.py | 148 + .../guiders/entropy_rectifying_guidance.py | 0 src/diffusers/guiders/guider_utils.py | 215 + src/diffusers/guiders/skip_layer_guidance.py | 251 ++ .../guiders/smoothed_energy_guidance.py | 244 + .../tangential_classifier_free_guidance.py | 137 + src/diffusers/hooks/__init__.py | 2 + src/diffusers/hooks/_common.py | 43 + src/diffusers/hooks/_helpers.py | 271 ++ src/diffusers/hooks/layer_skip.py | 231 + .../hooks/smoothed_energy_guidance_utils.py | 158 + src/diffusers/modular_pipelines/__init__.py | 84 + .../components_manager.py | 495 ++- .../modular_pipelines/modular_pipeline.py | 2247 ++++++++++ .../modular_pipeline_utils.py | 616 +++ src/diffusers/modular_pipelines/node_utils.py | 519 +++ .../stable_diffusion_xl/__init__.py | 51 + .../stable_diffusion_xl/before_denoise.py | 1764 ++++++++ .../stable_diffusion_xl/decoders.py | 215 + .../stable_diffusion_xl/denoise.py | 1392 ++++++ .../stable_diffusion_xl/encoders.py | 858 ++++ .../stable_diffusion_xl/modular_loader.py | 174 + .../modular_pipeline_block_mappings.py | 126 + .../modular_pipeline_presets.py | 43 + src/diffusers/pipelines/__init__.py | 6 - src/diffusers/pipelines/modular_pipeline.py | 1704 ------- .../pipelines/pipeline_loading_utils.py | 21 +- src/diffusers/pipelines/pipeline_utils.py | 3 +- .../pipelines/stable_diffusion_xl/__init__.py | 24 - .../pipeline_stable_diffusion_xl_modular.py | 3909 ----------------- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../dummy_torch_and_transformers_objects.py | 2 +- src/diffusers/utils/dynamic_modules_utils.py | 85 +- src/diffusers/utils/torch_utils.py | 5 + 42 files changed, 11037 insertions(+), 6489 deletions(-) create mode 100644 src/diffusers/commands/custom_blocks.py delete mode 100644 src/diffusers/guider.py create mode 100644 src/diffusers/guiders/__init__.py create mode 100644 src/diffusers/guiders/adaptive_projected_guidance.py create mode 100644 src/diffusers/guiders/auto_guidance.py create mode 100644 src/diffusers/guiders/classifier_free_guidance.py create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py create mode 100644 src/diffusers/guiders/entropy_rectifying_guidance.py create mode 100644 src/diffusers/guiders/guider_utils.py create mode 100644 src/diffusers/guiders/skip_layer_guidance.py create mode 100644 src/diffusers/guiders/smoothed_energy_guidance.py create mode 100644 src/diffusers/guiders/tangential_classifier_free_guidance.py create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/layer_skip.py create mode 100644 src/diffusers/hooks/smoothed_energy_guidance_utils.py create mode 100644 src/diffusers/modular_pipelines/__init__.py rename src/diffusers/{pipelines => modular_pipelines}/components_manager.py (51%) create mode 100644 src/diffusers/modular_pipelines/modular_pipeline.py create mode 100644 src/diffusers/modular_pipelines/modular_pipeline_utils.py create mode 100644 src/diffusers/modular_pipelines/node_utils.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py delete mode 100644 src/diffusers/pipelines/modular_pipeline.py delete mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ae8768ae9f72..d78b759c85c1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -34,10 +34,12 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "modular_pipelines": [], "quantizers.quantization_config": [], "schedulers": [], "utils": [ @@ -130,12 +132,26 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "AdaptiveProjectedGuidance", + "AutoGuidance", + "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", + "SkipLayerGuidance", + "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "LayerSkipConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -239,13 +255,21 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularPipeline", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ModularLoader", + "ModularPipeline", + "ModularPipelineBlocks", + "ComponentSpec", + "ComponentsManager", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ @@ -494,12 +518,10 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularPipeline", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLAutoPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -526,6 +548,24 @@ ] ) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] + +else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoPipeline", + "StableDiffusionXLModularLoader", + ] + ) try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -731,10 +771,22 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + AdaptiveProjectedGuidance, + AutoGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + SkipLayerGuidance, + SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, + ) from .hooks import ( FasterCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, + SmoothedEnergyGuidanceConfig, + apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, ) @@ -837,12 +889,18 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, ) + from .modular_pipelines import ( + ModularLoader, + ModularPipeline, + ModularPipelineBlocks, + ComponentSpec, + ComponentsManager, + ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -1070,12 +1128,10 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, @@ -1100,7 +1156,16 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py new file mode 100644 index 000000000000..d2f2de3a8f9a --- /dev/null +++ b/src/diffusers/commands/custom_blocks.py @@ -0,0 +1,133 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + TODO +""" + +import ast +from argparse import ArgumentParser, Namespace +from pathlib import Path +import importlib.util +import os +from ..utils import logging +from . import BaseDiffusersCLICommand + + +EXPECTED_PARENT_CLASSES = ["PipelineBlock"] +CONFIG = "config.json" + +def conversion_command_factory(args: Namespace): + return CustomBlocksCommand(args.block_module_name, args.block_class_name) + + +class CustomBlocksCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + conversion_parser = parser.add_parser("custom_blocks") + conversion_parser.add_argument( + "--block_module_name", + type=str, + default="block.py", + help="Module filename in which the custom block will be implemented.", + ) + conversion_parser.add_argument( + "--block_class_name", type=str, default=None, help="Name of the custom block. If provided None, we will try to infer it." + ) + conversion_parser.set_defaults(func=conversion_command_factory) + + def __init__(self, block_module_name: str = "block.py", block_class_name: str = None): + self.logger = logging.get_logger("diffusers-cli/custom_blocks") + self.block_module_name = Path(block_module_name) + self.block_class_name = block_class_name + + def run(self): + # determine the block to be saved. + out = self._get_class_names(self.block_module_name) + classes_found = list({cls for cls, _ in out}) + + if self.block_class_name is not None: + child_class, parent_class = self._choose_block(out, self.block_class_name) + if child_class is None and parent_class is None: + raise ValueError( + "`block_class_name` could not be retrieved. Available classes from " + f"{self.block_module_name}:\n{classes_found}" + ) + else: + self.logger.info( + f"Found classes: {classes_found} will be using {classes_found[0]}. " + "If this needs to be changed, re-run the command specifying `block_class_name`." + ) + child_class, parent_class = out[0][0], out[0][1] + + # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory. + # the user is responsible for running it, so I guess that is safe? + module_name = f"__dynamic__{self.block_module_name.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + getattr(module, child_class)().save_pretrained(os.getcwd()) + + # or, we could create it manually. + # automap = self._create_automap(parent_class=parent_class, child_class=child_class) + # with open(CONFIG, "w") as f: + # json.dump(automap, f) + with open("requirements.txt", "w") as f: + f.write("") + + def _choose_block(self, candidates, chosen=None): + for cls, base in candidates: + if cls == chosen: + return cls, base + return None, None + + def _get_class_names(self, file_path): + source = file_path.read_text(encoding="utf-8") + try: + tree = ast.parse(source, filename=file_path) + except SyntaxError as e: + raise ValueError(f"Could not parse {file_path!r}: {e}") from e + + results: list[tuple[str, str]] = [] + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + + # extract all base names for this class + base_names = [ + bname for b in node.bases + if (bname := self._get_base_name(b)) is not None + ] + + # for each allowed base that appears in the class's bases, emit a tuple + for allowed in EXPECTED_PARENT_CLASSES: + if allowed in base_names: + results.append((node.name, allowed)) + + return results + + def _get_base_name(self, node: ast.expr): + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + val = self._get_base_name(node.value) + return f"{val}.{node.attr}" if val else node.attr + return None + + def _create_automap(self, parent_class, child_class): + module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] + auto_map = {f"{parent_class}": f"{module}.{child_class}"} + return {"auto_map": auto_map} + diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py index f582c3bcd0df..cdc7dad166f0 100644 --- a/src/diffusers/commands/diffusers_cli.py +++ b/src/diffusers/commands/diffusers_cli.py @@ -17,6 +17,7 @@ from .env import EnvironmentCommand from .fp16_safetensors import FP16SafetensorsCommand +from .custom_blocks import CustomBlocksCommand def main(): @@ -26,6 +27,7 @@ def main(): # Register commands EnvironmentCommand.register_subcommand(commands_parser) FP16SafetensorsCommand.register_subcommand(commands_parser) + CustomBlocksCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py deleted file mode 100644 index 7445b7ba97af..000000000000 --- a/src/diffusers/guider.py +++ /dev/null @@ -1,745 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .models.attention_processor import ( - Attention, - AttentionProcessor, - PAGCFGIdentitySelfAttnProcessor2_0, - PAGIdentitySelfAttnProcessor2_0, -) -from .utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on - Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). - - Args: - noise_cfg (`torch.Tensor`): - The predicted noise tensor for the guided diffusion process. - noise_pred_text (`torch.Tensor`): - The predicted noise tensor for the text-guided diffusion process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - A rescale factor applied to the noise predictions. - - Returns: - noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. - """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - pag_scale = guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - pipeline.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - def reset_guider(self, pipeline): - if ( - self.do_perturbed_attention_guidance - and hasattr(pipeline, "original_attn_proc") - and pipeline.original_attn_proc is not None - ): - pipeline.unet.set_attn_processor(pipeline.original_attn_proc) - pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size : self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The negative conditional input. It can be a single tensor or a list of tensors. It must have the same - length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - -class APGGuider: - """ - This class is used to guide the pipeline with APG (Adaptive Projected Guidance). - """ - - def normalized_guidance( - self, - pred_cond: torch.Tensor, - pred_uncond: torch.Tensor, - guidance_scale: float, - momentum_buffer: MomentumBuffer = None, - norm_threshold: float = 0.0, - eta: float = 1.0, - ): - """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion - Models](https://arxiv.org/pdf/2410.02416) - """ - diff = pred_cond - pred_uncond - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - v0, v1 = diff.double(), pred_cond.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) - normalized_update = diff_orthogonal + eta * diff_parallel - pred_guided = pred_cond + (guidance_scale - 1) * normalized_update - return pred_guided - - @property - def adaptive_projected_guidance_momentum(self): - return self._adaptive_projected_guidance_momentum - - @property - def adaptive_projected_guidance_rescale_factor(self): - return self._adaptive_projected_guidance_rescale_factor - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) - adaptive_projected_guidance_rescale_factor = guider_kwargs.get( - "adaptive_projected_guidance_rescale_factor", 15.0 - ) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum - self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - if adaptive_projected_guidance_momentum is not None: - self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) - else: - self.momentum_buffer = None - self.scheduler = pipeline.scheduler - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - if latents is None: - raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") - - sigma = self.scheduler.sigmas[self.scheduler.step_index] - noise_pred = latents - sigma * model_output - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = self.normalized_guidance( - noise_pred_text, - noise_pred_uncond, - self.guidance_scale, - self.momentum_buffer, - self.adaptive_projected_guidance_rescale_factor, - ) - noise_pred = (latents - noise_pred) / sigma - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..3c1ee293382d --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ..utils import is_torch_available + + +if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .auto_guidance import AutoGuidance + from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance + + GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..ef2f3f2c8420 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,184 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class AdaptiveProjectedGuidance(BaseGuidance): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_apg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_apg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_apg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py new file mode 100644 index 000000000000..791cc582add2 --- /dev/null +++ b/src/diffusers/guiders/auto_guidance.py @@ -0,0 +1,177 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class AutoGuidance(BaseGuidance): + """ + AutoGuidance: https://huggingface.co/papers/2406.02507 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + auto_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. + auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + dropout (`float`, *optional*): + The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or + `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + auto_guidance_layers: Optional[Union[int, List[int]]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + dropout: Optional[float] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.auto_guidance_layers = auto_guidance_layers + self.auto_guidance_config = auto_guidance_config + self.dropout = dropout + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if auto_guidance_layers is None and auto_guidance_config is None: + raise ValueError( + "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance." + ) + if auto_guidance_layers is not None and auto_guidance_config is not None: + raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") + if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None): + raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") + + if auto_guidance_layers is not None: + if isinstance(auto_guidance_layers, int): + auto_guidance_layers = [auto_guidance_layers] + if not isinstance(auto_guidance_layers, list): + raise ValueError( + f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." + ) + auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers] + + if isinstance(auto_guidance_config, LayerSkipConfig): + auto_guidance_config = [auto_guidance_config] + + if not isinstance(auto_guidance_config, list): + raise ValueError( + f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." + ) + + self.auto_guidance_config = auto_guidance_config + self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_ag_enabled() and self.is_unconditional: + for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name in self._auto_guidance_hook_names: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + registry.remove_hook(name, recurse=True) + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_ag_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_ag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_ag_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..a459e51cd083 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,132 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class ClassifierFreeGuidance(BaseGuidance): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0 + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..a722f2605036 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class ClassifierFreeZeroStarGuidance(BaseGuidance): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif not self._is_cfg_enabled(): + pred = pred_cond + else: + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond_dtype = cond.dtype + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(dtype=cond_dtype) diff --git a/src/diffusers/guiders/entropy_rectifying_guidance.py b/src/diffusers/guiders/entropy_rectifying_guidance.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..e8e873f5c88f --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +import torch + +from ..utils import get_logger + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class BaseGuidance: + r"""Base class providing the skeleton for implementing guidance techniques.""" + + _input_predictions = None + _identifier_key = "__guidance_identifier__" + + def __init__(self, start: float = 0.0, stop: float = 1.0): + self._start = start + self._stop = stop + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._count_prepared = 0 + self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None + self._enabled = True + + if not (0.0 <= start < 1.0): + raise ValueError( + f"Expected `start` to be between 0.0 and 1.0, but got {start}." + ) + if not (start <= stop <= 1.0): + raise ValueError( + f"Expected `stop` to be between {start} and 1.0, but got {stop}." + ) + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def disable(self): + self._enabled = False + + def enable(self): + self._enabled = True + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._count_prepared = 0 + + def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: + """ + Set the input fields for the guidance technique. The input fields are used to specify the names of the + returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is + obtained from the values of the provided keyword arguments to this method. + + Args: + **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + + Example: + + ``` + data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": } + + BaseGuidance.set_input_fields( + latents="latents", + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + ) + ``` + """ + for key, value in kwargs.items(): + is_string = isinstance(value, str) + is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + if not (is_string or is_tuple_of_str_with_len_2): + raise ValueError( + f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." + ) + self._input_fields = kwargs + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + self._count_prepared += 1 + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + """ + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in + subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + modifications made during `prepare_models`. + """ + pass + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + + def __call__(self, data: List["BlockState"]) -> Any: + if not all(hasattr(d, "noise_pred") for d in data): + raise ValueError("Expected all data to have `noise_pred` attribute.") + if len(data) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data." + ) + forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data} + return self.forward(**forward_inputs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + + @property + def num_conditions(self) -> int: + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") + + @classmethod + def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of + the `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..modular_pipelines.modular_pipeline import BlockState + + if input_fields is None: + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..7c19f6391f41 --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,251 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..3986da913f82 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,244 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are + not combined together such as Flux) + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not + provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of + `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." + ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module): + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..017693fd9f07 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + shift = pred_cond - pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..9d0e96e9e79e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -5,5 +5,7 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..3d9c99e8189f --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9043ffc41838 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..65a99464ba2f --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,231 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. + dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. + """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + if not math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + with AttentionScoreSkipFunctionMode(): + output = self.fn_ref.original_forward(*args, **kwargs) + else: + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook(config.dropout) + registry.register_hook(hook, name) + + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) + registry.register_hook(hook, name) + + if config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook(config.dropout) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..f0366e29887f --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. + If `None`, `to_q` is used by default. + """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian + blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, + this implementation also assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. + """ + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice.clone() + + return query diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..4499634d9fbd --- /dev/null +++ b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,84 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineBlocks", + "ModularPipeline", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "ModularLoader", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularLoader, + ModularPipelineBlocks, + ModularPipeline, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) + from .components_manager import ComponentsManager +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py similarity index 51% rename from src/diffusers/pipelines/components_manager.py rename to src/diffusers/modular_pipelines/components_manager.py index 6d7665e29292..992353389b95 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -26,6 +26,10 @@ logging, ) from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec + + +import uuid if is_accelerate_available(): @@ -229,54 +233,209 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload + class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component): - if name in self.components: - logger.warning(f"Overriding existing component '{name}' in ComponentsManager") - self.components[name] = component - self.added_time[name] = time.time() + + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): + """ + Lookup component_ids by name, collection, or load_id. + """ + if components is None: + components = self.components + + if name: + ids_by_name = set() + for component_id, component in components.items(): + comp_name = self._id_to_name(component_id) + if comp_name == name: + ids_by_name.add(component_id) + else: + ids_by_name = set(components.keys()) + if collection: + ids_by_collection = set() + for component_id, component in components.items(): + if component_id in self.collections[collection]: + ids_by_collection.add(component_id) + else: + ids_by_collection = set(components.keys()) + if load_id: + ids_by_load_id = set() + for name, component in components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + ids_by_load_id.add(name) + else: + ids_by_load_id = set(components.keys()) + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) + return ids + + @staticmethod + def _id_to_name(component_id: str): + return "_".join(component_id.split("_")[:-1]) + + def add(self, name, component, collection: Optional[str] = None): + + component_id = f"{name}_{uuid.uuid4()}" + + # check for duplicated components + for comp_id, comp in self.components.items(): + if comp == component: + comp_name = self._id_to_name(comp_id) + if comp_name == name: + logger.warning( + f"component '{name}' already exists as '{comp_id}'" + ) + component_id = comp_id + break + else: + logger.warning( + f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + + # check for duplicated load_id and warn (we do not delete for you) + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] + + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id) + logger.warning( + f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + + if collection: + if collection not in self.collections: + self.collections[collection] = set() + if not component_id in self.collections[collection]: + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) + self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + else: + logger.info(f"Added component '{name}' as '{component_id}'") + if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + return component_id - def remove(self, name): - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") + + def remove(self, component_id: str = None): + + if component_id not in self.components: + logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - - self.components.pop(name) - self.added_time.pop(name) + + component = self.components.pop(component_id) + self.added_time.pop(component_id) + + for collection in self.collections: + if component_id in self.collections[collection]: + self.collections[collection].remove(component_id) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - - # YiYi TODO: looking into improving the search pattern - def get(self, names: Union[str, List[str]]): + else: + if isinstance(component, torch.nn.Module): + component.to("cpu") + del component + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): """ - Get components by name with simple pattern matching. + Select components by name with simple pattern matching. Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys Returns: - Single component if names is str and matches one component, - dict of components if names matches multiple components or is a list + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True """ + + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -286,33 +445,45 @@ def get(self, names: Union[str, List[str]]): # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern } + if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") - - # Exact match - elif names in self.components: - if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -323,31 +494,46 @@ def get(self, names: Union[str, List[str]]): elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") - return matches if len(matches) > 1 else next(iter(matches.values())) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches elif isinstance(names, list): results = {} for name in names: - result = self.get(name) - if isinstance(result, dict): - results.update(result) - else: - results[name] = result - logger.info(f"Getting multiple components: {list(results.keys())}") - return results + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -391,11 +577,12 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + # YiYi TODO: add quantization info + def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: - name: Name of the component to get info for + component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. @@ -404,23 +591,32 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No If fields is specified, returns only those fields. If a single field is requested as string, returns just that field's value. """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") - component = self.components[name] + component = self.components[component_id] # Build complete info dict first info = { - "model_id": name, - "added_time": self.added_time[name], + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -454,12 +650,64 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find the maximum collection name length + all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + "collection": max_collection_len } # Create the header lines @@ -476,17 +724,33 @@ def __repr__(self): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + output += dash_line # Other components section @@ -495,12 +759,24 @@ def __repr__(self): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += dash_line # Add additional component info @@ -508,7 +784,8 @@ def __repr__(self): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): @@ -517,7 +794,7 @@ def __repr__(self): return output - def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. @@ -527,17 +804,12 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st If provided, components will be named as "{prefix}_{component_name}" **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend AutoModel to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: self.add(component_name, component) else: @@ -546,6 +818,59 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st f"1. remove the existing component with remove('{component_name}')\n" f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + + def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + + # if component_id is provided, return the component + if component_id is not None and (name is not None or collection is not None or load_id is not None): + raise ValueError(" if component_id is provided, name, collection, and load_id must be None") + elif component_id is not None: + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + return self.components[component_id] + + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py new file mode 100644 index 000000000000..84b9b594d758 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -0,0 +1,2247 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + + +import traceback +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union, Optional +from copy import deepcopy + + +import torch +from tqdm.auto import tqdm +import re +import os +import importlib + +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..utils import ( + is_accelerate_available, + logging, + PushToHubMixin, +) +from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + format_components, + format_configs, + format_inputs_short, + format_intermediates_short, + make_doc_string, +) +from .components_manager import ComponentsManager +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code + +from copy import deepcopy +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +MODULAR_LOADER_MAPPING = OrderedDict( + [ + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), + ] +) + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ + self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) + + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ + self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) + + def get_input(self, key: str, default: Any = None) -> Any: + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) + + def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.intermediates.get(key, default) for key in keys} + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return {key: value for key, value in self.__dict__.items()} + + def __repr__(self): + def format_value(v): + # Handle tensors directly + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + + # Handle lists of tensors + elif isinstance(v, list): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"List[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle tuples of tensors + elif isinstance(v, tuple): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle dicts with tensor values + elif isinstance(v, dict): + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict + + # Default case + return repr(v) + + attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) + return f"BlockState(\n{attributes}\n)" + + +class ModularPipelineBlocks(ConfigMixin): + """ + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + """ + + config_name = "config.json" + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("TODO") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) + block_kwargs = { + name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + } + + return block_cls(**block_kwargs) + + def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + """ + create a ModularLoader, optionally accept modular_repo to load from hub. + """ + loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + diffusers_module = importlib.import_module("diffusers") + loader_class = getattr(diffusers_module, loader_class_name) + + # Create deep copies to avoid modifying the original specs + component_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) + # Create the loader with the updated specs + specs = component_specs + config_specs + + loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection) + modular_pipeline = ModularPipeline(blocks=self, loader=loader) + return modular_pipeline + + +class PipelineBlock(ModularPipelineBlocks): + + model_name = None + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + # raise NotImplementedError("description method must be implemented in subclasses") + return "TODO: add a description" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + + @property + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + def _get_outputs(self): + return self.intermediates_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? + # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks + @property + def outputs(self) -> List[OutputParam]: + return self._get_outputs() + + def _get_required_inputs(self): + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + + def _get_required_intermediates_inputs(self): + input_names = [] + for input_param in self.intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + return self._get_required_intermediates_inputs() + + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - use format_components with add_empty_lines=False + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") + + # Inputs section + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str + + # Intermediates section + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates = f"Intermediates:\n{intermediates_str}" + + return ( + f"{class_name}(\n" + f" Class: {base_class}\n" + f"{desc}" + f"{components}\n" + f"{configs}\n" + f" {inputs}\n" + f" {intermediates}\n" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + + # YiYi TODO: input and inteermediate inputs with same name? should warn? + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + + def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): + # TODO: factor out this logic. + cls_name = self.__class__.__name__ + + full_mod = type(self).__module__ + module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + auto_map = {f"{parent_module}": f"{module}.{cls_name}"} + _component_names = [c.name for c in self.expected_components] + + self.register_to_config(auto_map=auto_map, _component_names=_component_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + self._internal_dict = FrozenDict(config) + + +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): + warnings.warn( + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + else: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + + return list(combined_dict.values()) + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) + + +class AutoPipelineBlocks(ModularPipelineBlocks): + """ + A class that automatically selects a block to run based on the inputs. + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for default + """ + + block_classes = [] + block_names = [] + block_trigger_inputs = [] + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + default_blocks = [t for t in self.block_trigger_inputs if t is None] + # can only have 1 or 0 default block, and has to put in the last + # the order of blocksmatters here because the first block with matching trigger will be dispatched + # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] + # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img + if len(default_blocks) > 1 or ( + len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None + ): + raise ValueError( + f"In {self.__class__.__name__}, exactly one None must be specified as the last element " + "in block_trigger_inputs." + ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) + self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + + @property + def required_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediates_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + + @property + def intermediates_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediates_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find default block first (if any) + + block = self.trigger_to_block_map.get(None) + for input_name in self.block_trigger_inputs: + if input_name is not None and state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + + try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + + return trigger_inputs + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + +class SequentialPipelineBlocks(ModularPipelineBlocks): + """ + A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + """ + block_classes = [] + block_names = [] + + + @property + def description(self): + return "" + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block classes or instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + + # Create instances if classes are provided + blocks = {} + for name, block in blocks_dict.items(): + if inspect.isclass(block): + blocks[name] = block() + else: + blocks[name] = block + + instance.block_classes = [block.__class__ for block in blocks.values()] + instance.block_names = list(blocks.keys()) + instance.blocks = blocks + return instance + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + added_inputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + for inp in block.intermediates_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediates_outputs + if name not in inp_names: + named_outputs.append((name, block.intermediates_outputs)) + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + # YiYi TODO: I think we can remove the outputs property + @property + def outputs(self) -> List[str]: + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + for block_name, block in self.blocks.items(): + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return pipeline, state + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + return fn_recursive_get_trigger(self.blocks) + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def _traverse_trigger_blocks(self, trigger_inputs): + # Convert trigger_inputs to a set for easier manipulation + active_triggers = set(trigger_inputs) + def fn_recursive_traverse(block, block_name, active_triggers): + result_blocks = OrderedDict() + + # sequential(include loopsequential) or PipelineBlock + if not hasattr(block, 'block_trigger_inputs'): + if hasattr(block, 'blocks'): + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} + result_blocks.update(blocks_to_update) + else: + # PipelineBlock + result_blocks[block_name] = block + # Add this block's output names to active triggers if defined + if hasattr(block, 'outputs'): + active_triggers.update(out.name for out in block.outputs) + return result_blocks + + # auto + else: + # Find first block_trigger_input that matches any value in our active_triggers + this_block = None + matching_trigger = None + for trigger_input in block.block_trigger_inputs: + if trigger_input is not None and trigger_input in active_triggers: + this_block = block.trigger_to_block_map[trigger_input] + matching_trigger = trigger_input + break + + # If no matches found, try to get the default (None) block + if this_block is None and None in block.block_trigger_inputs: + this_block = block.trigger_to_block_map[None] + matching_trigger = None + + if this_block is not None: + # sequential/auto (keep traversing) + if hasattr(this_block, 'blocks'): + result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) + else: + # PipelineBlock + result_blocks[block_name] = this_block + # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? + if hasattr(this_block, 'outputs'): + active_triggers.update(out.name for out in this_block.outputs) + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_execution_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs is not None: + + if not isinstance(trigger_inputs, (list, tuple, set)): + trigger_inputs = [trigger_inputs] + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if trigger_inputs is None: + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [trigger_inputs_all[0]] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineBlocks): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +# YiYi TODO: +# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config +# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() +class ModularLoader(ConfigMixin, PushToHubMixin): + """ + Base class for all Modular pipelines loaders. + + """ + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the modular_model_index.json configuration for serialization + 4. Adds components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + E.g., register_components(unet=unet_model, text_encoder=encoder_model) + + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None + """ + for name, module in kwargs.items(): + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + continue + + # check if it is the first time registration, i.e. calling from __init__ + is_registered = hasattr(self, name) + + # make sure the component is created from ComponentSpec + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + if module is not None: + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) + + else: + # if module is None, e.g. self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based oncomponent spec + library, class_name = None, None + component_spec_dict = self._component_spec_to_dict(component_spec) + register_dict = {name: (library, class_name, component_spec_dict)} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + self.register_to_config(**register_dict) + setattr(self, name, module) + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + continue + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularLoader.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → replace but send debug log + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"ModularLoader.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # update modular_model_index.json config + self.register_to_config(**register_dict) + # finally set models + setattr(self, name, module) + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + + + + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + """ + Initialize the loader with a list of component specs and config specs. + """ + self._component_manager = component_manager + self._collection = collection + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) + } + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) + } + + # update component_specs and config_specs from modular_repo + if pretrained_model_name_or_path is not None: + config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) + + for name, value in config_dict.items(): + # only update component_spec for from_pretrained components + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component + self.register_components(**register_components_dict) + + default_configs = {} + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + + @property + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self._component_specs.keys() + if hasattr(self, name) + } + + def update(self, **kwargs): + """ + Update components and configs after instance creation. + + Args: + + """ + """ + Update components and configuration values after the loader has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. Update configuration values (e.g., changing requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) + + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) + ``` + """ + + # extract component_specs_updates & config_specs_updates from `specs` + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} + + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in update() method " + f"because it is not supported in `ComponentSpec.from_component()`. " + f"Please pass a ComponentSpec object instead." + ) + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec + + if len(kwargs) > 0: + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) + + + config_to_register = {} + for name, new_value in passed_config_values.items(): + + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register) + + + # YiYi TODO: support map for additional from_pretrained kwargs + def load(self, component_names: Optional[List[str]] = None, **kwargs): + """ + Load selectedcomponents from specs. + + Args: + component_names: List of component names to load + **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: + - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 + - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. + """ + # if not specific name, load all the components with default_creation_method == "from_pretrained" + if component_names is None: + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") + + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.load(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) + + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + # only pick up pretrained components from the repo + if component_spec_dict.get("repo", None) is not None: + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) + + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } + """ + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) + + +class ModularPipeline: + """ + Base class for all Modular pipelines. + + Args: + blocks: ModularPipelineBlocks, the blocks to be used in the pipeline + loader: ModularLoader, the loader to be used in the pipeline + """ + + def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): + self.blocks = blocks + self.loader = loader + + def __repr__(self): + blocks_class = self.blocks.__class__.__name__ + loader_class = self.loader.__class__.__name__ + return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})" + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediates_inputs: + state.add_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.add_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.add_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self.blocks(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + + def load_components(self, component_names: Optional[List[str]] = None, **kwargs): + self.loader.load(component_names=component_names, **kwargs) + + def update_components(self, **kwargs): + self.loader.update(**kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs) + return pipeline + + def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): + self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + + @property + def doc(self): + return self.blocks.doc \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..ced059551f9a --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -0,0 +1,616 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal + +from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict, ConfigMixin +from collections import OrderedDict + +if is_torch_available(): + import torch + + +class InsertableOrderedDict(OrderedDict): + def insert(self, key, value, index): + items = list(self.items()) + + # Remove key if it already exists to avoid duplicates + items = [(k, v) for k, v in items if k != key] + + # Insert at the specified index + items.insert(index, (key, value)) + + # Clear and update self + self.clear() + self.update(items) + + # Return self for method chaining + return self + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + + @classmethod + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component created by `create` or `load` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` or `load` method") + # throw a error if component is created with `create` method but not a subclass of ConfigMixin + # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + "We currently only support creating ComponentSpec from a component with " + "created with `ComponentSpec.load` method" + "or created with `ComponentSpec.create` and a subclass of ConfigMixin" + ) + + type_hint = component.__class__ + default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" + + if isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not created with `load` method). + """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError( + f"`type_hint` is required when using from_config creation method." + ) + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Unable to load {self.name} using load method: {e}") + + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None + + +# YiYi Notes: both inputs and intermediates_inputs are InputParam objects +# however some fields are not relevant for intermediates_inputs +# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed +# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs +# -> should we use different class for inputs and intermediates_inputs? +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str = None + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + kwargs_type: str = None # YiYi notes: remove this feature (maybe) + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + # YiYi Notes: remove this line if we remove kwargs_type + name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name + param_str = f"{param_indent}{name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. + """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py new file mode 100644 index 000000000000..5f5e1c6c782d --- /dev/null +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -0,0 +1,519 @@ +from ..configuration_utils import ConfigMixin +from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks +from .modular_pipeline_utils import InputParam, OutputParam +from ..image_processor import PipelineImageInput +from pathlib import Path +import json +import os + +from typing import Union, List, Optional, Tuple +import torch +import PIL +import numpy as np +import logging +logger = logging.getLogger(__name__) + +# YiYi Notes: this is actually for SDXL, put it here for now +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + +SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} + + +DEFAULT_PARAM_MAPS = { + "prompt": { + "label": "Prompt", + "type": "string", + "default": "a bear sitting in a chair drinking a milkshake", + "display": "textarea", + }, + "negative_prompt": { + "label": "Negative Prompt", + "type": "string", + "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + "display": "textarea", + }, + + "num_inference_steps": { + "label": "Steps", + "type": "int", + "default": 25, + "min": 1, + "max": 1000, + }, + "seed": { + "label": "Seed", + "type": "int", + "default": 0, + "min": 0, + "display": "random", + }, + "width": { + "label": "Width", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "height": { + "label": "Height", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "images": { + "label": "Images", + "type": "image", + "display": "output", + }, + "image": { + "label": "Image", + "type": "image", + "display": "input", + }, +} + +DEFAULT_TYPE_MAPS ={ + "int": { + "type": "int", + "default": 0, + "min": 0, + }, + "float": { + "type": "float", + "default": 0.0, + "min": 0.0, + }, + "str": { + "type": "string", + "default": "", + }, + "bool": { + "type": "boolean", + "default": False, + }, + "image": { + "type": "image", + }, +} + +DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"] +DEFAULT_CATEGORY = "Modular Diffusers" +DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"] +DEFAULT_PARAMS_GROUPS_KEYS = { + "text_encoders": ["text_encoder", "tokenizer"], + "ip_adapter_embeds": ["ip_adapter_embeds"], + "prompt_embeddings": ["prompt_embeds"], +} + + +def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): + """ + Get the group name for a given parameter name, if not part of a group, return None + e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None + """ + if name is None: + return None + for group_name, group_keys in group_params_keys.items(): + for group_key in group_keys: + if group_key in name: + return group_name + return None + + +class ModularNode(ConfigMixin): + + config_name = "node_config.json" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + return cls(blocks, **kwargs) + + def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): + self.blocks = blocks + + if label is None: + label = self.blocks.__class__.__name__ + # blocks param name -> mellon param name + self.name_mapping = {} + + input_params = {} + # pass or create a default param dict for each input + # e.g. for prompt, + # prompt = { + # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers + # "label": "Prompt", + # "type": "string", + # "default": "a bear sitting in a chair drinking a milkshake", + # "display": "textarea"} + # if type is not specified, it'll be a "custom" param of its own type + # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) + # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} + # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} + inputs = self.blocks.inputs + self.blocks.intermediates_inputs + for inp in inputs: + param = kwargs.pop(inp.name, None) + if param: + # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...}) + input_params[inp.name] = param + mellon_name = param.pop("name", inp.name) + if mellon_name != inp.name: + self.name_mapping[inp.name] = mellon_name + continue + + if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): + continue + + if inp.name in DEFAULT_PARAM_MAPS: + # first check if it's in the default param map, if so, directly use that + param = DEFAULT_PARAM_MAPS[inp.name].copy() + elif get_group_name(inp.name): + param = get_group_name(inp.name) + if inp.name not in self.name_mapping: + self.name_mapping[inp.name] = param + else: + # if not, check if it's in the SDXL input schema, if so, + # 1. use the type hint to determine the type + # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} + if inp.type_hint is not None: + type_str = str(inp.type_hint).lower() + else: + inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None) + type_str = str(inp_spec.type_hint).lower() if inp_spec else "" + for type_key, type_param in DEFAULT_TYPE_MAPS.items(): + if type_key in type_str: + param = type_param.copy() + param["label"] = inp.name + param["display"] = "input" + break + else: + param = inp.name + # add the param dict to the inp_params dict + input_params[inp.name] = param + + + component_params = {} + for comp in self.blocks.expected_components: + param = kwargs.pop(comp.name, None) + if param: + component_params[comp.name] = param + mellon_name = param.pop("name", comp.name) + if mellon_name != comp.name: + self.name_mapping[comp.name] = mellon_name + continue + + to_exclude = False + for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: + if exclude_key in comp.name: + to_exclude = True + break + if to_exclude: + continue + + if get_group_name(comp.name): + param = get_group_name(comp.name) + if comp.name not in self.name_mapping: + self.name_mapping[comp.name] = param + elif comp.name in DEFAULT_MODEL_KEYS: + param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"} + else: + param = comp.name + # add the param dict to the model_params dict + component_params[comp.name] = param + + output_params = {} + if isinstance(self.blocks, SequentialPipelineBlocks): + last_block_name = list(self.blocks.blocks.keys())[-1] + outputs = self.blocks.blocks[last_block_name].intermediates_outputs + else: + outputs = self.blocks.intermediates_outputs + + for out in outputs: + param = kwargs.pop(out.name, None) + if param: + output_params[out.name] = param + mellon_name = param.pop("name", out.name) + if mellon_name != out.name: + self.name_mapping[out.name] = mellon_name + continue + + if out.name in DEFAULT_PARAM_MAPS: + param = DEFAULT_PARAM_MAPS[out.name].copy() + param["display"] = "output" + else: + group_name = get_group_name(out.name) + if group_name: + param = group_name + if out.name not in self.name_mapping: + self.name_mapping[out.name] = param + else: + param = out.name + # add the param dict to the outputs dict + output_params[out.name] = param + + if len(kwargs) > 0: + logger.warning(f"Unused kwargs: {kwargs}") + + register_dict = { + "category": category, + "label": label, + "input_params": input_params, + "component_params": component_params, + "output_params": output_params, + "name_mapping": self.name_mapping, + } + self.register_to_config(**register_dict) + + def setup(self, components, collection=None): + self.blocks.setup_loader(component_manager=components, collection=collection) + self._components_manager = components + + @property + def mellon_config(self): + return self._convert_to_mellon_config() + + def _convert_to_mellon_config(self): + + node = {} + node["label"] = self.config.label + node["category"] = self.config.category + + node_param = {} + for inp_name, inp_param in self.config.input_params.items(): + if inp_name in self.name_mapping: + mellon_name = self.name_mapping[inp_name] + else: + mellon_name = inp_name + if isinstance(inp_param, str): + param = { + "label": inp_param, + "type": inp_param, + "display": "input", + } + else: + param = inp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") + + + for comp_name, comp_param in self.config.component_params.items(): + if comp_name in self.name_mapping: + mellon_name = self.name_mapping[comp_name] + else: + mellon_name = comp_name + if isinstance(comp_param, str): + param = { + "label": comp_param, + "type": comp_param, + "display": "input", + } + else: + param = comp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") + + + for out_name, out_param in self.config.output_params.items(): + if out_name in self.name_mapping: + mellon_name = self.name_mapping[out_name] + else: + mellon_name = out_name + if isinstance(out_param, str): + param = { + "label": out_param, + "type": out_param, + "display": "output", + } + else: + param = out_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") + node["params"] = node_param + return node + + def save_mellon_config(self, file_path): + """ + Save the Mellon configuration to a JSON file. + + Args: + file_path (str or Path): Path where the JSON file will be saved + + Returns: + Path: Path to the saved config file + """ + file_path = Path(file_path) + + # Create directory if it doesn't exist + os.makedirs(file_path.parent, exist_ok=True) + + # Create a combined dictionary with module definition and name mapping + config = { + "module": self.mellon_config, + "name_mapping": self.name_mapping + } + + # Save the config to file + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=2) + + logger.info(f"Mellon config and name mapping saved to {file_path}") + + return file_path + + @classmethod + def load_mellon_config(cls, file_path): + """ + Load a Mellon configuration from a JSON file. + + Args: + file_path (str or Path): Path to the JSON file containing Mellon config + + Returns: + dict: The loaded combined configuration containing 'module' and 'name_mapping' + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Config file not found: {file_path}") + + with open(file_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + logger.info(f"Mellon config loaded from {file_path}") + + + return config + + def process_inputs(self, **kwargs): + + params_components = {} + for comp_name, comp_param in self.config.component_params.items(): + logger.debug(f"component: {comp_name}") + mellon_comp_name = self.name_mapping.get(comp_name, comp_name) + if mellon_comp_name in kwargs: + if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: + comp = kwargs[mellon_comp_name].pop(comp_name) + else: + comp = kwargs.pop(mellon_comp_name) + if comp: + params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) + + + params_run = {} + for inp_name, inp_param in self.config.input_params.items(): + logger.debug(f"input: {inp_name}") + mellon_inp_name = self.name_mapping.get(inp_name, inp_name) + if mellon_inp_name in kwargs: + if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: + inp = kwargs[mellon_inp_name].pop(inp_name) + else: + inp = kwargs.pop(mellon_inp_name) + if inp is not None: + params_run[inp_name] = inp + + return_output_names = list(self.config.output_params.keys()) + + return params_components, params_run, return_output_names + + def execute(self, **kwargs): + params_components, params_run, return_output_names = self.process_inputs(**kwargs) + + self.blocks.loader.update(**params_components) + output = self.blocks.run(**params_run, output=return_output_names) + return output + + + + + + + + + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..f3f961d61a13 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] + _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline + from .modular_loader import StableDiffusionXLModularLoader + from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep + from .decoders import StableDiffusionXLAutoDecodeStep +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..07f096249c0d --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1764 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module + +from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + ModularLoader, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): + + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + ] + + def check_inputs(self, components, block_state): + + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the scheduler's timesteps for inference" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Prepare latents step that prepares the latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] + + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * components.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", False),] + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + + +# Before denoise +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py new file mode 100644 index 000000000000..ca848e20984f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +import numpy as np +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...models import AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...utils import logging + +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...configuration_utils import FrozenDict + +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + +class StableDiffusionXLDecodeStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components + @staticmethod + def upcast_vae(components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + components.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + latents = block_state.latents + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + else: + latents = latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), + InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] + + self.add_block_state(state, block_state) + + return components, state + + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + + +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 000000000000..4d7ab12cf009 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,1392 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from ...configuration_utils import FrozenDict +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import unwrap_module + +from ...guiders import ClassifierFreeGuidance +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, + AutoPipelineBlocks, + LoopSequentialPipelineBlocks, + BlockState, +) +from dataclasses import asdict + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that denoise the latents with guidance (with controlnet). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (3): scheduler step to update latents +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that update the latents. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + #YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that update the latents (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + ) + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + ) + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only). " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + ) +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. " + "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " + "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " + "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + ) + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when no mask is provided." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided." + ) + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when no mask is provided." + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + ) + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." + ) + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") + +# controlnet = unwrap_module(components.controlnet) + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..ca4efe2c4a7f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,858 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor, unwrap_module +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...configuration_utils import FrozenDict + +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec + +import numpy as np + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return( + "Text Encoder step that generate text_embeddings to guide the image generation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + ] + + @staticmethod + def check_inputs(block_state): + + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "Vae Encoder step that encode the input image into a latent representation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." + ) + + + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), + ] + + + @property + def description(self) -> str: + return ( + "Vae encoder step that prepares the image and mask for the inpainting process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) + + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) + + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + + return components, state + + + +# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) +# Encode +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py new file mode 100644 index 000000000000..4af942af64e6 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +import PIL +import torch +import numpy as np + +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...image_processor import PipelineImageInput +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging + +from ..modular_pipeline import ModularLoader +from ..modular_pipeline_utils import InputParam, OutputParam + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py new file mode 100644 index 000000000000..00cd5ca3735a --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..modular_pipeline_utils import InsertableOrderedDict + +# Import all the necessary block classes +from .denoise import ( + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseLoop, + StableDiffusionXLInpaintDenoiseLoop +) +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep +) +from .encoders import ( + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLVaeEncoderStep, + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintDecodeStep, + StableDiffusionXLAutoDecodeStep +) + + +# YiYi notes: comment out for now, work on this later +# block mapping +TEXT2IMAGE_BLOCKS = InsertableOrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseLoop), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseLoop), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = InsertableOrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLInpaintDenoiseLoop), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = InsertableOrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +IP_ADAPTER_BLOCKS = InsertableOrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + +AUTO_BLOCKS = InsertableOrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + +AUTO_CORE_BLOCKS = InsertableOrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py new file mode 100644 index 000000000000..637c7ac306d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks + +from .denoise import StableDiffusionXLAutoDenoiseStep +from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep +from .decoders import StableDiffusionXLAutoDecodeStep +from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" + + + + diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 61ed023ce06b..011f23ed371c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,7 +47,6 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularPipeline"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -330,8 +329,6 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", - "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -481,7 +478,6 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularPipeline from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -706,9 +702,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py deleted file mode 100644 index b50d00dbc219..000000000000 --- a/src/diffusers/pipelines/modular_pipeline.py +++ /dev/null @@ -1,1704 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union - - -import torch -from tqdm.auto import tqdm -import re - -from ..configuration_utils import ConfigMixin -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, -) -from .pipeline_loading_utils import _get_pipeline_class - - -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -MODULAR_PIPELINE_MAPPING = OrderedDict( - [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), - ] -) - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - - def add_input(self, key: str, value: Any): - self.inputs[key] = value - - def add_intermediate(self, key: str, value: Any): - self.intermediates[key] = value - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.inputs.get(key, default) for key in keys} - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.intermediates.get(key, default) for key in keys} - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - -@dataclass -class InputParam: - name: str - default: Any = None - required: bool = False - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - -@dataclass -class OutputParam: - name: str - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - input_parts.append(inp.name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) - -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) - - - -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "" - - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output - - -class PipelineBlock: - # YiYi Notes: do we need this? - # pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list - expected_components = [] - expected_configs = [] - model_name = None - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("inputs method must be implemented in subclasses") - - @property - def intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_inputs method must be implemented in subclasses") - - @property - def intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_outputs method must be implemented in subclasses") - - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs - - @property - def required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - def __init__(self): - self.components: Dict[str, Any] = {} - self.auxiliaries: Dict[str, Any] = {} - self.configs: Dict[str, Any] = {} - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - - main_components = [] - auxiliary_components = [] - for k in all_components: - component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - if k in getattr(self, "auxiliary_components", []): - auxiliary_components.append(component_str) - else: - main_components.append(component_str) - - components = "Components:\n" + "\n".join(main_components) - if auxiliary_components: - components += "\n Auxiliaries:\n" + "\n".join(auxiliary_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs = "Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" - for k in all_configs - ) - - # Inputs section - inputs_str = format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f" {components}\n" - f" {configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" - ) - - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value - - # Check intermediates - for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. - - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {} # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if output_param.name not in combined_dict: - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - -class AutoPipelineBlocks: - """ - A class that automatically selects a block to run based on the inputs. - - Attributes: - block_classes: List of block classes to be used - block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_names = [] - block_trigger_inputs = [] - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last - # the order of blocksmatters here because the first block with matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): - raise ValueError( - f"In {self.__class__.__name__}, exactly one None must be specified as the last element " - "in block_trigger_inputs." - ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs - - @property - def required_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - @property - def required_intermediates_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - - @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and state.get_input(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - elif input_name is not None and state.get_intermediate(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.warning(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - - return trigger_inputs - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items() - ) - - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - - return ( - f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" - f")" - ) - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - -class SequentialPipelineBlocks: - """ - A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. - """ - block_classes = [] - block_names = [] - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": - """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - - Args: - blocks_dict: Dictionary mapping block names to block instances - - Returns: - A new SequentialPipelineBlocks instance - """ - instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict - return instance - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs - - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_inputs(self) -> List[str]: - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): - try: - pipeline, state = block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return pipeline, state - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - return fn_recursive_get_trigger(self.blocks) - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) - - def fn_recursive_traverse(block, block_name, active_triggers): - result_blocks = OrderedDict() - - # sequential or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): - # sequential - for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - result_blocks.update(blocks_to_update) - else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): - active_triggers.update(out.name for out in block.outputs) - return result_blocks - - # auto - else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - matching_trigger = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - matching_trigger = None - - if this_block is not None: - # sequential/auto - if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - if hasattr(this_block, 'outputs'): - active_triggers.update(out.name for out in this_block.outputs) - - return result_blocks - - all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - all_blocks.update(blocks_to_update) - return all_blocks - - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs - - if trigger_inputs is not None: - - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) - return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) - - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - - return ( - f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" - f")" - ) - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - -class ModularPipeline(ConfigMixin): - """ - Base class for all Modular pipelines. - - """ - - config_name = "model_index.json" - _exclude_from_cpu_offload = [] - - def __init__(self, block): - self.pipeline_block = block - - # add default components from pipeline_block (e.g. guider) - for key, value in block.components.items(): - setattr(self, key, value) - - # add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt) - self.register_to_config(**block.configs) - - # add default auxiliaries from pipeline_block (e.g. image_processor) - for key, value in block.auxiliaries.items(): - setattr(self, key, value) - - @classmethod - def from_block(cls, block): - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] - modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) - - return modular_pipeline_class(block) - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - - def get_execution_blocks(self, *trigger_inputs): - return self.pipeline_block.get_execution_blocks(*trigger_inputs) - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - @property - def expected_components(self): - return self.pipeline_block.expected_components - - @property - def expected_configs(self): - return self.pipeline_block.expected_configs - - @property - def components(self): - components = {} - for name in self.expected_components: - if hasattr(self, name): - components[name] = getattr(self, name) - return components - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) - - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.pipeline_block(self, state) - except Exception: - error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - def update_states(self, **kwargs): - """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None components are being - overwritten. - - Args: - kwargs (dict): Keyword arguments to update the states. - """ - - for component_name in self.expected_components: - if component_name in kwargs: - if hasattr(self, component_name) and getattr(self, component_name) is not None: - current_component = getattr(self, component_name) - new_component = kwargs[component_name] - - if not isinstance(new_component, current_component.__class__): - logger.info( - f"Overwriting existing component '{component_name}' " - f"(type: {current_component.__class__.__name__}) " - f"with type: {new_component.__class__.__name__})" - ) - elif isinstance(current_component, torch.nn.Module): - if id(current_component) != id(new_component): - logger.info( - f"Overwriting existing component '{component_name}' " - f"(type: {type(current_component).__name__}) " - f"with new value (type: {type(new_component).__name__})" - ) - - setattr(self, component_name, kwargs.pop(component_name)) - - configs_to_add = {} - for config_name in self.expected_configs: - if config_name in kwargs: - configs_to_add[config_name] = kwargs.pop(config_name) - self.register_to_config(**configs_to_add) - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.pipeline_block.inputs: - params[input_param.name] = input_param.default - return params - - def __repr__(self): - output = "ModularPipeline:\n" - output += "==============================\n\n" - - block = self.pipeline_block - - # List the pipeline block structure first - output += "Pipeline Block:\n" - output += "--------------\n" - if hasattr(block, "blocks"): - output += f"{block.__class__.__name__}\n" - base_class = block.__class__.__bases__[0].__name__ - output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - for sub_block_name, sub_block in block.blocks.items(): - if hasattr(block, "block_trigger_inputs"): - trigger_input = block.block_to_trigger_map[sub_block_name] - trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - else: - output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - else: - output += f"{block.__class__.__name__}\n" - output += "\n" - - # List the components registered in the pipeline - output += "Registered Components:\n" - output += "----------------------\n" - for name, component in self.components.items(): - output += f"{name}: {type(component).__name__}" - if hasattr(component, "dtype") and hasattr(component, "device"): - output += f" (dtype={component.dtype}, device={component.device})" - output += "\n" - output += "\n" - - # List the configs registered in the pipeline - output += "Registered Configs:\n" - output += "------------------\n" - for name, config in self.config.items(): - output += f"{name}: {config!r}\n" - output += "\n" - - # Add auto blocks section - if hasattr(block, "trigger_inputs") and block.trigger_inputs: - output += "------------------\n" - output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - output += f"Trigger Inputs: {block.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in block.trigger_inputs if t is not None) - output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - output += "Check `.doc` of returned object for more information.\n\n" - - # List the call parameters - full_doc = self.pipeline_block.doc - if "------------------------" in full_doc: - full_doc = full_doc.split("------------------------")[0].rstrip() - output += full_doc - - return output - - # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline - # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to - def to(self, *args, **kwargs): - r""" - Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the - arguments of `self.to(*args, **kwargs).` - - - - If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, - the returned pipeline is a copy of self with the desired torch.dtype and torch.device. - - - - - Here are the ways to call `to`: - - - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the - specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - Arguments: - dtype (`torch.dtype`, *optional*): - Returns a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - device (`torch.Device`, *optional*): - Returns a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - silence_dtype_warnings (`str`, *optional*, defaults to `False`): - Whether to omit warnings if the target `dtype` is not compatible with the target `device`. - - Returns: - [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. - """ - dtype = kwargs.pop("dtype", None) - device = kwargs.pop("device", None) - silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - - dtype_arg = None - device_arg = None - if len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype_arg = args[0] - else: - device_arg = torch.device(args[0]) if args[0] is not None else None - elif len(args) == 2: - if isinstance(args[0], torch.dtype): - raise ValueError( - "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." - ) - device_arg = torch.device(args[0]) if args[0] is not None else None - dtype_arg = args[1] - elif len(args) > 2: - raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") - - if dtype is not None and dtype_arg is not None: - raise ValueError( - "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - dtype = dtype or dtype_arg - - if device is not None and device_arg is not None: - raise ValueError( - "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - device = device or device_arg - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. - def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and ( - isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) - or hasattr(module._hf_hook, "hooks") - and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) - ) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." - ) - - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." - ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - - if is_loaded_in_8bit and dtype is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." - ) - - if is_loaded_in_8bit and device is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." - ) - else: - module.to(device, dtype) - - if ( - module.dtype == torch.float16 - and str(device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - return self diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 2b8afeffa00a..8b422798713f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -331,6 +331,20 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): @@ -412,7 +426,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": + if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -839,7 +853,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f9486aa6386..49575e99763a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1948,9 +1948,10 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] missing_modules = ( set(expected_modules) - - set(pipeline._optional_components) + - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 584b260eaaa8..8088fbcfceba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,18 +29,6 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] - _import_structure["pipeline_stable_diffusion_xl_modular"] = [ - "StableDiffusionXLControlNetDenoiseStep", - "StableDiffusionXLDecodeLatentsStep", - "StableDiffusionXLDenoiseStep", - "StableDiffusionXLInputStep", - "StableDiffusionXLModularPipeline", - "StableDiffusionXLPrepareAdditionalConditioningStep", - "StableDiffusionXLPrepareLatentsStep", - "StableDiffusionXLSetTimestepsStep", - "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLAutoPipeline", - ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -60,18 +48,6 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline - from .pipeline_stable_diffusion_xl_modular import ( - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLModularPipeline, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, - ) try: if not (is_transformers_available() and is_flax_available()): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index f743f442cc40..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3909 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...guider import CFGGuider -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import is_compiled_module, randn_tensor -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularPipeline, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -class StableDiffusionXLLoraStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "unet"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc" - " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" - " for more details" - ) - - - @property - def inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [] - - def __init__(self): - super().__init__() - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["unet"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") - - -class StableDiffusionXLIPAdapterStep(PipelineBlock): - expected_components = ["image_encoder", "feature_extractor", "unet"] - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - required=True, - type_hint=PipelineImageInput, - description="The image(s) to be used as ip adapter" - ), - InputParam( - "guidance_scale", - default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - def __init__(self): - super().__init__() - self.components["image_encoder"] = None - self.components["feature_extractor"] = None - self.components["unet"] = None - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device - - data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - do_classifier_free_guidance=data.do_classifier_free_guidance, - ) - if data.do_classifier_free_guidance: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] - expected_configs = ["force_zeros_for_empty_prompt"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return( - "Text Encoder step that generate text_embeddings to guide the image generation" - ) - - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts to guide the image generation.", - ), - InputParam( - name="prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders", - ), - InputParam( - name="negative_prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", - ), - InputParam( - name="negative_prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders", - ), - InputParam( - name="cross_attention_kwargs", - type_hint=Optional[dict], - description="A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor]", - ), - InputParam( - name="guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", - ), - InputParam( - name="clip_skip", - type_hint=Optional[int], - ), - ] - - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - ] - - def __init__(self): - super().__init__() - self.configs["force_zeros_for_empty_prompt"] = True - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["tokenizer"] = None - self.components["tokenizer_2"] = None - - def check_inputs(self, pipeline, data): - - if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") - elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device - - - # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None - ) - ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) = pipeline.encode_prompt( - data.prompt, - data.prompt_2, - data.device, - 1, - data.do_classifier_free_guidance, - data.negative_prompt, - data.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, - ) - # Add outputs - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - name="generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)" - "to make generation deterministic." - ), - InputParam( - name="height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - name="width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - def __init__(self): - super().__init__() - self.components["vae"] = None - self.auxiliaries["image_processor"] = VaeImageProcessor() - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) - - data.batch_size = data.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." - ) - - - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "image", - required=True, - type_hint=PipelineImageInput, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - required=True, - type_hint=PipelineImageInput, - description="`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be " - "repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted " - "to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) " - "instead of 3, so the expected shape would be `(B, H, W, 1)`." - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to " - "image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region " - "with the same aspect ratio of the image and contains all masked area, and then expand that area based " - "on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before " - "resizing to the original image size for inpainting. This is useful when the masked area is small while " - "the image is large and contain information irrelevant for inpainting, such as background." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - def __init__(self): - super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - self.components["vae"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" - else: - data.crops_coords = None - data.resize_mode = "default" - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) - - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) - - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - - return pipeline, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt.", - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, pipeline, data): - - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." - ) - - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." - ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype - - _, seq_len, _ = data.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the" - " expense of slower inference." - ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), - InputParam( - "strength", - default=0.3, - type_hint=float, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="The denoising start value to use for the scheduler. Determines the starting point of the denoising process." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt. Defaults to 1." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - data.timesteps, data.num_inference_steps = pipeline.get_timesteps( - data.num_inference_steps, - data.strength, - data.device, - denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, - ) - data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference." - ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic."), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), - InputParam( - "strength", - default=0.9999, - type_hint=float, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - data.is_strength_max = data.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: - data.masked_image_latents = None - - data.add_noise = True if data.denoising_start is None else False - - data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor - data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - - data.latents, data.noise = pipeline.prepare_latents_inpaint( - data.batch_size * data.num_images_per_prompt, - pipeline.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - image=data.image_latents, - timestep=data.latent_timestep, - is_strength_max=data.is_strength_max, - add_noise=data.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image_latents, - data.batch_size * data.num_images_per_prompt, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - expected_components = ["vae", "scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - data.add_noise = True if data.denoising_start is None else False - if data.latents is None: - data.latents = pipeline.prepare_latents_img2img( - data.image_latents, - data.latent_timestep, - data.batch_size, - data.num_images_per_prompt, - data.dtype, - data.device, - data.generator, - data.add_noise, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @staticmethod - def check_inputs(pipeline, data): - if ( - data.height is not None - and data.height % pipeline.vae_scale_factor != 0 - or data.width is not None - and data.width % pipeline.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." - ) - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.dtype is None: - data.dtype = pipeline.vae.dtype - - data.device = pipeline._execution_device - - self.check_inputs(pipeline, data) - - data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor - data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor - data.num_channels_latents = pipeline.num_channels_latents - data.latents = pipeline.prepare_latents( - data.batch_size * data.num_images_per_prompt, - data.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - expected_configs = ["requires_aesthetics_score"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "original_size", - type_hint=Optional[Tuple[int]], - description="If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. " - "`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as " - "explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Optional[Tuple[int]], - description="For most cases, `target_size` should be set to the desired height and width of the generated image. If " - "not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in " - "section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a specific image resolution. Part of SDXL's " - "micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_target_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a target image resolution. It should be as same " - "as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of " - "https://huggingface.co/papers/2307.01952" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="`crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position " - "`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's " - "micro-conditioning" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality." - ), - InputParam( - "aesthetic_score", - default=6.0, - type_hint=float, - description="Used to simulate an aesthetic score of the generated image by influencing the positive text condition. " - "Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_aesthetic_score", - default=2.0, - type_hint=float, - description="Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. " - "Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - def __init__(self): - super().__init__() - self.configs["requires_aesthetics_score"] = False - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - if data.negative_original_size is None: - data.negative_original_size = data.original_size - if data.negative_target_size is None: - data.negative_target_size = data.target_size - - data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img( - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.aesthetic_score, - data.negative_aesthetic_score, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - dtype=data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The original size (height, width) of the image that conditions the generation process. If different from target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The target size (height, width) of the generated image. For most cases, this should be set to the desired output dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "negative_target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to condition the generation process. Setting this to (0, 0) typically produces well-centered images. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to negatively condition the generation process. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. For more information, see: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality."), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - data.add_time_ids = pipeline._get_add_time_ids( - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = pipeline._get_add_time_ids( - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - else: - data.negative_add_time_ids = data.add_time_ids - - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLDenoiseStep(PipelineBlock): - expected_components = ["unet", "scheduler", "guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." - ), - InputParam( - "guidance_rescale", - type_hint=float, - default=0.0, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." - ), - InputParam( - "cross_attention_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the AttentionProcessor." - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - type_hint=float, - default=0.0, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the Guider." - ), - InputParam( - "num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - # Prepare conditional inputs using the guider - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # expand the latents if we are doing classifier free guidance - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - - # inpainting - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - # predict the noise residual - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance( - data.noise_pred, - timestep=t, - latents=data.latents, - ) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching." - ), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying." - ), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying." - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." - ), - InputParam( - "guidance_rescale", - default=0.0, - type_hint=float, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'." - ), - InputParam( - "cross_attention_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the AttentionProcessor." - ), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - data.control_guidance_start, data.control_guidance_end = ( - mult * [data.control_guidance_start], - mult * [data.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): - data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - data.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - data.control_image = pipeline.prepare_control_image( - image=data.control_image, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in data.control_image: - control_image = pipeline.prepare_control_image( - image=control_image_, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - - control_images.append(control_image) - - data.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - keeps = [ - 1.0 - float(i / len(data.timesteps) < s or (i + 1) / len(data.timesteps) > e) - for s, e in zip(data.control_guidance_start, data.control_guidance_end) - ] - data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - # (5) Denoise loop - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return " The denoising step for the controlnet union model, works for inpainting, image-to-image, and text-to-image tasks" - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching."), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying."), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying."), - InputParam( - "control_mode", - required=True, - type_hint=List[int], - description="The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment" - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1."), - InputParam( - "guidance_rescale", - default=0.0, - type_hint=float, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."), - InputParam( - "cross_attention_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the AttentionProcessor."), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic."), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider."), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - - # (1.1) - # control guidance - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - - # (1.2) - # global_pool_conditions & guess_mode - data.global_pool_conditions = controlnet.config.global_pool_conditions - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.3) - # control_type - data.num_control_type = controlnet.config.num_control_type - - # (1.4) - # control_type - if not isinstance(data.control_image, list): - data.control_image = [data.control_image] - - if not isinstance(data.control_mode, list): - data.control_mode = [data.control_mode] - - if len(data.control_image) != len(data.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - data.control_type = [0 for _ in range(data.num_control_type)] - for control_idx in data.control_mode: - data.control_type[control_idx] = 1 - - data.control_type = torch.Tensor(data.control_type) - - # (1.5) - # prepare control_image - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.prepare_control_image( - image=data.control_image[idx], - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - data.height, data.width = data.control_image[idx].shape[-2:] - - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - data.controlnet_keep.append( - 1.0 - - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) - ) - - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) - - data.control_type = ( - data.control_type.reshape(1, -1) - .to(data.device, dtype=data.prompt_embeds.dtype) - ) - repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] - data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 9 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - expected_components = ["vae"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "Step that decodes the denoised latents into images" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "output_type", - type_hint=str, - default="pil", - description="The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - - def __init__(self): - super().__init__() - self.components["vae"] = None - self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if not data.output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast - - if data.needs_upcasting: - pipeline.upcast_vae() - data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif data.latents.dtype != pipeline.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(data.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - data.has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None - ) - data.has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None - ) - if data.has_latents_mean and data.has_latents_std: - data.latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents_std = ( - torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean - else: - data.latents = data.latents / pipeline.vae.config.scaling_factor - - data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if data.needs_upcasting: - pipeline.vae.to(dtype=torch.float16) - else: - data.images = data.latents - - # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - data.images = pipeline.watermark.apply_watermark(data.images) - - data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - type_hint=PipelineImageInput, - required=True, - description="The mask image(s) to use for inpainting, white pixels in the mask will be repainted, while black pixels will be preserved. If mask_image is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be (B, H, W, 1). Must be a `PIL.Image.Image`" - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - default=None, - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied. If set, it will find a rectangular region with the same aspect ratio as the image that contains all masked areas, then expand that area by this margin. The image and mask_image are cropped to this expanded area before resizing to the original size for inpainting. Useful when the masked area is small in a large image with irrelevant background information." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.padding_mask_crop is not None and data.crops_coords is not None: - data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images] - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [(InputParam("return_dict", type_hint=bool, default=True, description="Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple."))] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if not data.return_dict: - data.images = (data.images,) - else: - data.images = StableDiffusionXLPipelineOutput(images=data.images) - self.add_block_state(state, data) - return pipeline, state - - -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." - - -# Before denoise -class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - -class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - -class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - - -class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] - block_names = ["inpaint", "img2img", "text2img"] - block_trigger_inputs = ["mask", "image_latents", None] - - @property - def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ - " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." - - -# Denoise -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_union", "controlnet", "unet"] - block_trigger_inputs = ["control_mode", "control_image", None] - - @property - def description(self): - return "Denoise step that denoise the latents.\n" + \ - "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \ - " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." - -# After denoise - -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decode the denoised latents into images outputs. -This is a sequential pipeline blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - @property - def description(self): - return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return "Decode step that decode the denoised latents into images outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - -class StableDiffusionXLModularPipeline( - ModularPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids - def _get_add_time_ids_img2img( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - def prepare_control_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents - # YiYi TODO: refactor using _encode_vae_image - def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents - def prepare_latents_inpaint( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - if self.vae.config.force_upcast: - self.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = self.vae.config.scaling_factor * image_latents - - return image_latents - - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b94b9ad4a7e3..0d28cb81af38 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1388,7 +1388,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipeline(metaclass=DummyObject): +class ModularLoader(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3e7c3a735ee9..a512b107cf96 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2432,7 +2432,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularPipeline(metaclass=DummyObject): +class StableDiffusionXLModularLoader(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af8983..5d5eb23969ab 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -15,13 +15,16 @@ """Utilities to dynamically load objects from the Hub.""" import importlib +import signal import inspect import json import os import re import shutil import sys +import threading from pathlib import Path +from types import ModuleType from typing import Dict, Optional, Union from urllib import request @@ -37,6 +40,8 @@ # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" +TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15)) +_HF_REMOTE_CODE_LOCK = threading.Lock() def get_diffusers_versions(): @@ -154,15 +159,87 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): + if trust_remote_code is None: + if has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + + return trust_remote_code + + +def get_class_in_module(class_name, module_path, force_reload=False): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + + module_spec.loader.exec_module(module) if class_name is None: return find_pipeline_class(module) + return getattr(module, class_name) @@ -454,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a5df07e4a3c2..622c0d124f97 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). From 7ad01a6350c2c430690bf00b892a894aed305eec Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 20 Jun 2025 07:23:14 +0200 Subject: [PATCH 066/170] rename modular_pipeline_block_mappings.py to modular_block_mapping --- ..._pipeline_block_mappings.py => modular_block_mappings.py} | 5 ----- 1 file changed, 5 deletions(-) rename src/diffusers/modular_pipelines/stable_diffusion_xl/{modular_pipeline_block_mappings.py => modular_block_mappings.py} (96%) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py similarity index 96% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py index 00cd5ca3735a..4ffd685df044 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py @@ -106,11 +106,6 @@ ("decode", StableDiffusionXLAutoDecodeStep) ]) -AUTO_CORE_BLOCKS = InsertableOrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - SDXL_SUPPORTED_BLOCKS = { "text2img": TEXT2IMAGE_BLOCKS, From 5a8c1b5f19fd053b21e9edd72b1da8d7c5f2482c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 20 Jun 2025 07:24:14 +0200 Subject: [PATCH 067/170] add block mappings to modular_diffusers.stable_diffusion_xl.__init__ --- src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index f3f961d61a13..1fbc141ac3de 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -25,6 +25,7 @@ _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] + _import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -37,6 +38,7 @@ from .modular_loader import StableDiffusionXLModularLoader from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep from .decoders import StableDiffusionXLAutoDecodeStep + from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS else: import sys From 8913d59bf3470996eb1c5ca0291b7f71c3a0394b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 20 Jun 2025 07:25:20 +0200 Subject: [PATCH 068/170] add to method to modular loader, copied from DiffusionPipeline, not tested yet --- .../modular_pipelines/modular_pipeline.py | 193 +++++++++++++++++- 1 file changed, 190 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 84b9b594d758..ca5932fafedc 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -19,6 +19,7 @@ from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union, Optional +from typing_extensions import Self from copy import deepcopy @@ -2012,9 +2013,195 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): # Register all components at once self.register_components(**components_to_register) - # YiYi TODO: should support to method - def to(self, *args, **kwargs): - pass + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to + def to(self, *args, **kwargs) -> Self: + r""" + Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the + arguments of `self.to(*args, **kwargs).` + + + + If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, + the returned pipeline is a copy of self with the desired torch.dtype and torch.device. + + + + + Here are the ways to call `to`: + + - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the + specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + Arguments: + dtype (`torch.dtype`, *optional*): + Returns a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + device (`torch.Device`, *optional*): + Returns a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + silence_dtype_warnings (`str`, *optional*, defaults to `False`): + Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + + Returns: + [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + """ + dtype = kwargs.pop("dtype", None) + device = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg + device_type = torch.device(device).type if device is not None else None + pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items()) + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module) + + if is_loaded_in_8bit_bnb: + return False + + return hasattr(module, "_hf_hook") and ( + isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + or hasattr(module._hf_hook, "hooks") + and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) + ) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." + ) + + if device_type in ["cuda", "xpu"]: + if pipeline_is_sequentially_offloaded and not pipeline_has_bnb: + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + # PR: https://github.com/huggingface/accelerate/pull/3223/ + elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"): + raise ValueError( + "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and device_type in ["cuda", "xpu"]: + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." + ) + + # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration + if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available(): + os.environ["PT_HPU_GPU_MIGRATION"] = "1" + logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1") + + import habana_frameworks.torch # noqa: F401 + + # HPU hardware check + if not (hasattr(torch, "hpu") and torch.hpu.is_available()): + raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.") + + os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1" + logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1") + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) + is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module) + + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." + ) + + if is_loaded_in_8bit_bnb and device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." + ) + + # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling + # components can be from outside diffusers too, but still have group offloading enabled. + if ( + self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module) + and device is not None + ): + logger.warning( + f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported." + ) + + # This can happen for `transformer` models. CPU placement was added in + # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. + if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): + module.to(device=device) + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded: + module.to(device, dtype) + + if ( + module.dtype == torch.float16 + and str(device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved From 45392cce11c566a7df7aeeddf845dbd24a3b2311 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 20 Jun 2025 07:38:21 +0200 Subject: [PATCH 069/170] update the description of StableDiffusionXLDenoiseLoopWrapper --- .../stable_diffusion_xl/denoise.py | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 4d7ab12cf009..4485c17e97d4 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -687,9 +687,11 @@ class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): def description(self) -> str: return ( "Denoise step that iteratively denoise the latents. " - "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " - "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " - "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" + " - `StableDiffusionXLDenoiseLoopDenoiser`\n" + " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" ) # control_cond @@ -699,10 +701,12 @@ class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents with controlnet. " - "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " - "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " - "and finally `StableDiffusionXLDenoiseLoopAfterDenoiser` to update the latents." + "Denoise step that iteratively denoise the latents with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" + " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" ) # mask @@ -712,10 +716,12 @@ class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents(for inpainting task only). " - "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " - "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLDenoiseLoopDenoiser`, " - "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + "Denoise step that iteratively denoise the latents(for inpainting task only). \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" + " - `StableDiffusionXLDenoiseLoopDenoiser`\n" + " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" ) # control_cond + mask class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): @@ -724,10 +730,12 @@ class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoop @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. " - "Its loop logic is defined in parent class `StableDiffusionXLDenoiseLoopWrapper` " - "and at each iteration, it runs blocks defined in `blocks` sequencially, i.e. `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser` and `StableDiffusionXLControlNetDenoiseLoopDenoiser`, " - "and finally `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser` to update the latents." + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" + " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" ) From 9e58856b7a80cf995307be62ef6dca839ffa11c6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 21 Jun 2025 04:24:44 +0200 Subject: [PATCH 070/170] add __repr__ method for InsertableOrderedDict --- .../modular_pipelines/modular_pipeline_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index ced059551f9a..868c09106043 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -41,6 +41,16 @@ def insert(self, key, value, index): # Return self for method chaining return self + + def __repr__(self): + if not self: + return "InsertableOrderedDict()" + + items = [] + for i, (key, value) in enumerate(self.items()): + items.append(f"{i}: ({repr(key)}, {repr(value)})") + + return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])" # YiYi TODO: From 04c16d0a56a04682a9e4a5128966d80d89a13372 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 21 Jun 2025 04:25:12 +0200 Subject: [PATCH 071/170] update --- .../stable_diffusion_xl/denoise.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 4485c17e97d4..045b6968aa37 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -54,7 +54,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + return "step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" @property @@ -89,7 +89,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object" @property @@ -165,7 +165,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: return ( - "Step within the denoising loop that denoise the latents with guidance. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + "Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @property @@ -269,7 +269,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that denoise the latents with guidance (with controlnet). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + return "step within the denoising loop that denoise the latents with guidance (with controlnet). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -458,7 +458,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that update the latents. Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + return "step within the denoising loop that update the latents. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -521,7 +521,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that update the latents (for inpainting workflow only). Used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object, e.g. `StableDiffusionXLDenoiseLoopWrapper`" + return "step within the denoising loop that update the latents (for inpainting workflow only). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -686,9 +686,9 @@ class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents. " + "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `blocks` sequencially:\n" " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" " - `StableDiffusionXLDenoiseLoopDenoiser`\n" " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" @@ -703,7 +703,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `blocks` sequencially:\n" " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" @@ -718,7 +718,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents(for inpainting task only). \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `blocks` sequencially:\n" " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" " - `StableDiffusionXLDenoiseLoopDenoiser`\n" " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" @@ -732,7 +732,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "and at each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `blocks` sequencially:\n" " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" From 083479c3656ecf22439854429e57f308fe84888b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 21 Jun 2025 04:28:10 +0200 Subject: [PATCH 072/170] ordereddict -> insertableOrderedDict; make sure loader to method works --- .../modular_pipelines/modular_pipeline.py | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index ca5932fafedc..43505aabee23 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -48,6 +48,7 @@ format_inputs_short, format_intermediates_short, make_doc_string, + InsertableOrderedDict ) from .components_manager import ComponentsManager from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code @@ -66,6 +67,7 @@ ) + @dataclass class PipelineState: """ @@ -622,7 +624,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): block_trigger_inputs = [] def __init__(self): - blocks = OrderedDict() + blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks @@ -958,7 +960,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo return instance def __init__(self): - blocks = OrderedDict() + blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks @@ -1449,7 +1451,7 @@ def outputs(self) -> List[str]: def __init__(self): - blocks = OrderedDict() + blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks @@ -1662,6 +1664,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): """ config_name = "modular_model_index.json" + hf_device_map = None def register_components(self, **kwargs): @@ -2013,7 +2016,26 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): # Register all components at once self.register_components(**components_to_register) - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active + def _maybe_raise_error_if_group_offload_active( + self, raise_error: bool = False, module: Optional[torch.nn.Module] = None + ) -> bool: + from ..hooks.group_offloading import _is_group_offload_enabled + + components = self.components.values() if module is None else [module] + components = [component for component in components if isinstance(component, torch.nn.Module)] + for component in components: + if _is_group_offload_enabled(component): + if raise_error: + raise ValueError( + "You are trying to apply model/sequential CPU offloading to a pipeline that contains components " + "with group offloading enabled. This is not supported. Please disable group offloading for " + "components of the pipeline to use other offloading methods." + ) + return True + return False + + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to def to(self, *args, **kwargs) -> Self: r""" Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the @@ -2050,6 +2072,10 @@ def to(self, *args, **kwargs) -> Self: Returns: [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ + from ..pipelines.pipeline_utils import _check_bnb_status, DiffusionPipeline + from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version + + dtype = kwargs.pop("dtype", None) device = kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) @@ -2152,8 +2178,7 @@ def module_is_offloaded(module): os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1" logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1") - module_names, _ = self._get_signature_keys(self) - modules = [getattr(self, n, None) for n in module_names] + modules = self.components.values() modules = [m for m in modules if isinstance(m, torch.nn.Module)] is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded @@ -2431,4 +2456,12 @@ def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = No @property def doc(self): - return self.blocks.doc \ No newline at end of file + return self.blocks.doc + + def to(self, *args, **kwargs): + self.loader.to(*args, **kwargs) + return self + + @property + def components(self): + return self.loader.components \ No newline at end of file From 4751d456f2e31475da148dcc587017a2b7a8f340 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 22 Jun 2025 12:31:16 +0200 Subject: [PATCH 073/170] shorten loop subblock name --- .../stable_diffusion_xl/denoise.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 045b6968aa37..3a8bca74b5a0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -42,7 +42,7 @@ # YiYi experimenting composible denoise loop # loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): +class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -76,7 +76,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): +class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -147,7 +147,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): +class StableDiffusionXLLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -251,7 +251,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): +class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -446,7 +446,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): +class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -508,7 +508,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (3): scheduler step to update latents (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): +class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -680,7 +680,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # composing the denoising loops class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_classes = [StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @property @@ -689,14 +689,14 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `blocks` sequencially:\n" - " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" - " - `StableDiffusionXLDenoiseLoopDenoiser`\n" - " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" ) # control_cond class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_classes = [StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @property def description(self) -> str: @@ -704,14 +704,14 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `blocks` sequencially:\n" - " - `StableDiffusionXLDenoiseLoopBeforeDenoiser`\n" - " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" - " - `StableDiffusionXLDenoiseLoopAfterDenoiser`\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" ) # mask class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_classes = [StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLInpaintLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @property def description(self) -> str: @@ -719,13 +719,13 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents(for inpainting task only). \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `blocks` sequencially:\n" - " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" - " - `StableDiffusionXLDenoiseLoopDenoiser`\n" - " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" ) # control_cond + mask class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_classes = [StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLInpaintLoopAfterDenoiser] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @property def description(self) -> str: @@ -733,9 +733,9 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `blocks` sequencially:\n" - " - `StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser`\n" - " - `StableDiffusionXLControlNetDenoiseLoopDenoiser`\n" - " - `StableDiffusionXLInpaintDenoiseLoopAfterDenoiser`\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" ) From d12531ddf7ac7ab9eff5ec4d24f4b54d43f98cb9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 22 Jun 2025 12:32:04 +0200 Subject: [PATCH 074/170] lora: only remove hooks that we add back --- src/diffusers/loaders/lora_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 850c4b2b4bc1..08a64c348784 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -456,7 +456,8 @@ def _func_optionally_disable_offloading(_pipeline): logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + if is_sequential_cpu_offload or is_model_cpu_offload: + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) return (is_model_cpu_offload, is_sequential_cpu_offload) From 19545fd3e148067740da9fc1977edbc56cfc0e30 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 22 Jun 2025 12:59:19 +0200 Subject: [PATCH 075/170] update components manager __repr__ --- src/diffusers/modular_pipelines/components_manager.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 992353389b95..bdc24d474a32 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -703,7 +703,7 @@ def format_device(component, info): col_widths = { "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 15, # Reduced since using more compact format + "device": 20, "dtype": 15, "size": 10, "load_id": max_load_id_len, @@ -725,7 +725,7 @@ def format_device(component, info): output += "Models:\n" + dash_line # Column headers output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line @@ -790,7 +790,6 @@ def format_device(component, info): output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): output += f" IP-Adapter: Enabled\n" - output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" return output From 78d2454c7cd54dc6752ef37b6691a4db798c41c1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Jun 2025 16:06:17 +0200 Subject: [PATCH 076/170] fix --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- .../stable_diffusion_xl/before_denoise.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 43505aabee23..c26a9c7c8a76 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -947,7 +947,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance = cls() # Create instances if classes are provided - blocks = {} + blocks = InsertableOrderedDict() for name, block in blocks_dict.items(): if inspect.isclass(block): blocks[name] = block() diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 07f096249c0d..f6ff33967512 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -1698,7 +1698,13 @@ class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] block_names = ["controlnet_union", "controlnet"] block_trigger_inputs = ["control_mode", "control_image"] - + + @property + def description(self): + return "Controlnet Input step that prepare the controlnet input.\n" + \ + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." # Before denoise From 085ade03bef6437dd81bb0943f9db9edee76ad0c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Jun 2025 16:12:31 +0200 Subject: [PATCH 077/170] add doc (developer guide) --- docs/source/en/_toctree.yml | 3 + .../en/modular_diffusers/developer_guide.md | 689 ++++++++++++++++++ 2 files changed, 692 insertions(+) create mode 100644 docs/source/en/modular_diffusers/developer_guide.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4e62f3ef6182..d8fc32f38093 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -86,6 +86,9 @@ - local: hybrid_inference/api_reference title: API Reference title: Hybrid Inference +- sections: + - local: modular_diffusers/developer_guide + title: Developer Guide - sections: - local: using-diffusers/cogvideox title: CogVideoX diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md new file mode 100644 index 000000000000..175383f0d0ca --- /dev/null +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -0,0 +1,689 @@ +# Developer Guide: Building with Modular Diffusers + +To implement new pipelines in the modular framework, you can use this process: + +#### 1. **Start with an existing pipeline as a base** + - Identify which existing pipeline is most similar to your target + - Determine what part of the pipeline need modification + +#### 2. **Build a working pipeline structure first** + - Assemble the complete pipeline structure + - Use existing blocks wherever possible + - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet + +#### 3. **Set up an example and test incrementally** + - Create a simple inference script with expected inputs/outputs + - Test incrementally as you implement changes + +Let's see how this works with the Differential Diffusion example. + + +## Differential Diffusion Pipeline + +Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them. + +```python +IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseLoop), + ("decode", StableDiffusionXLDecodeStep) +]) +``` + +Note that "denoise" (`StableDiffusionXLDenoiseLoop`) is a loop that contains 3 loop blocks (more on SequentialLoopBlocks [here](https://colab.research.google.com/drive/1iVRjy_tOfmmm4gd0iVe0_Rl3c6cBzVqi?usp=sharing)) + +```python +denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() +print(denoise_blocks) +``` + +```out +StableDiffusionXLDenoiseLoop( + Class: StableDiffusionXLDenoiseLoopWrapper + + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + + + + Components: + scheduler (`EulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + unet (`UNet2DConditionModel`) + + Blocks: + [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser) + Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + + [1] denoiser (StableDiffusionXLLoopDenoiser) + Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + + [2] after_denoiser (StableDiffusionXLLoopAfterDenoiser) + Description: step within the denoising loop that update the latents. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + +) +``` + + +Img2img diffusion pipeline adds the same noise level across all pixels based on a single strength parameter, however, differential diffusion uses a change map where each pixel value represents when that region should start denoising. Regions with lower change map values get "frozen" earlier in the denoising process by replacing them with noised original latents, effectively giving them fewer denoising steps and thus preserving more of the original image. + +It has a different `prepare_latents` step and `denoise` step. At `parepare_latents` step, it prepares the change map and pre-computes the original noised latents for all timesteps. At each timestep during the denoising process, it selectively applies denoising based on the change map. Additionally, diff-diff does not use the `strengh` parameter, so its `set_timesteps` step is different from the one in image-to-image, but same as `set_timesteps` in text-to-image workflow. + +So, to implement the differential diffusion pipeline, we can use pipeline blocks from image-to-image and text-to-image workflow, and change the `prepare_latents` step and the `denoise` step (more specifically, we only need to change the first part of `denoise` step where we prepare the latent input for the denoiser model). + +Differential diffusion shares exact same pipeline structure as img2img. Here is a flowchart that puts the changes we need to make into the context of the pipeline structure. + + +![DiffDiff Pipeline Structure](https://mermaid.ink/img/pako:eNqVVO9r4kAQ_VeWLQWFKEk00eRDwZpa7Q-ucPfpYpE1mdWlcTdsVmpb-7_fZk1tTCl3J0Sy8968N5kZ9g0nIgUc4pUk-Rr9iuYc6d_Ibs14vlXoQYpNrtqo07lAo1jBTi2AlynysWIa6DJmG7KCBnZpsHHMSqkqNjaxKC5ALRTbQKEgLyosMthVnEvIiYRFRhRwVaBoNpmUT0W7MrTJkUbSdJEInlbwxMDXcQpcsAKq6OH_2mDTODIY4yt0J0ReUaYGnLXiJVChdSsB-enfPhBnhnjT-rCQj-1K_8Ygt62YUAVy8Ykf4FvU6XYu9rpuIGqPpvXSzs_RVEj2KrgiGUp02zNQTHBEM_FcK3BfQbBHd7qAst-PxvW-9WOrypnNylG0G9oRUMYBFeolg-IQTTJSFDqOUkZp-fwsQURZloVnlPpLf2kVSoonCM-SwCUuqY6dZ5aqddjLd1YiMiFLNrWorrxj9EOmP4El37lsl_9p5PzFqIqwVwgdN981fDM94bphH5I06R8NXZ_4QcPQPTFs6JltPrS6JssFhw9N817l27bdyM-lSKAo6iVBAAnQY0n9wLO9wbcluY7ruUFDtdguH74K0yENKDkK-8nAG6TfNrfy_bf-HjdrlOfZS7VYSAlU5JAwyhLE9WrWVw1dWdPTXauDsy8LUkdHtnX_pfMnBOvSGluRNbGurbuTHtdZN9Zts1MljC19_7EUh0puwcIbkBtSHvFbic6xWsMG5jjUrymRT3M85-86Jyf8txCbjzQptqs1DinJCn3a5qm-viJG9M26OUYlcH0_jsWWKxwGttHA4Rve4dD1el3H8_yh49hD3_X7roVfcNhx-l3b14PxvGHQ0xMa9t4t_Gp8na7tDvu-4w08HXecweD9D4X54ZI) + +ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones: + +```python +# Copy existing blocks as placeholders +class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): + """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later""" + # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep + +class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): + """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later""" + # ... same implementation as StableDiffusionXLLoopBeforeDenoiser +``` + +`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseLoop`. + +```python +class SDXLDiffDiffDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] +``` + +Now we can put together our differential diffusion pipeline. + +```python +DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() +DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] +DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep +DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseLoop + +dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS) +print(dd_blocks) +# At this point, the pipeline works exactly like img2img since our blocks are just copies +``` + +ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple exapmple so we can run the pipeline as we build it. diff-diff use same components as SDXL so we can fetch the models from a regular SDXL repo. + +```python +dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +dd_pipeline.load_componenets(torch_dtype=torch.float16) +dd_pipeline.to("cuda") +``` + +We will use this example script: + +```python + +image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") + +prompt = "a green pear" +negative_prompt = "blurry" + +image = dd_pipeline.run( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=25, + diffdiff_map=mask, + image=image, + output="images" +)[0] + +image.save("diffdiff_out.png") +``` + +If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`. +and you would get the same result as the original img2img pipeline. + +Let's modify the pipeline so that we can get expected result with this example script. + +We'll start with the `prepare_latents` step, as it is the first step that gets called right after the `input` step. The main changes are: +- new input `diffdiff_map`: It will become a new input to the pipeline after we built it. +- `num_inference_steps` and `timestesp` as intermediates inputs: Both variables are created in `set_timesteps` block, we need to list them as intermediates inputs so that we can now use them in `__call__`. +- A new component `mask_processor`: A default one will be created when we build the pipeline, but user can update it. +- Inside `__call__`, we created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`. We also need to list them as intermediates outputs so the we can use them in the `denoise` step later. + +I have two tips I want to share for this process: +1. use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline. +e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input. +2. insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. + +This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up with : +```diff +- class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ++ class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( +- "Step that prepares the latents for the image-to-image generation process" ++ "Step that prepares the latents for the differential diffusion generation process" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), ++ ComponentSpec( ++ "mask_processor", ++ VaeImageProcessor, ++ config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}), ++ default_creation_method="from_config", ++ ) + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ ++ InputParam("diffdiff_map",required=True), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), +- InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), ++ InputParam("timesteps",type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."), ++ InputParam("num_inference_steps", type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ ++ OutputParam("original_latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), ++ OutputParam("diffdiff_masks", type_hint=torch.Tensor, description="The masks used for the differential diffusion denoising process"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState): + block_state = self.get_block_state(state) + block_state.dtype = components.vae.dtype + block_state.device = components._execution_device + + block_state.add_noise = True if block_state.denoising_start is None else False ++ components.scheduler.set_begin_index(None) + + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, +- block_state.latent_timestep, ++ block_state.timesteps, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) ++ ++ latent_height = block_state.image_latents.shape[-2] ++ latent_width = block_state.image_latents.shape[-1] ++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width) ++ ++ diffdiff_map = diffdiff_map.squeeze(0).to(block_state.device) ++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps ++ thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(block_state.device) ++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0)) ++ block_state.original_latents = block_state.latents + + self.add_block_state(state, block_state) +``` + +This is the modified `before_denoiser` step, we use diff-diff map to freeze certain regions in the latents before each denoising step. + +```diff +class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( +- "step within the denoising loop that prepare the latent input for the denoiser" ++ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" + ) + ++ @property ++ def inputs(self) -> List[Tuple[str, Any]]: ++ return [ ++ InputParam("denoising_start"), ++ ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), ++ InputParam( ++ "original_latents", ++ type_hint=torch.Tensor, ++ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ++ ), ++ InputParam( ++ "diffdiff_masks", ++ type_hint=torch.Tensor, ++ description="The masks used for the differential diffusion denoising process, can be generated in prepare_latent step." ++ ), + ] + + @torch.no_grad() + def __call__(self, components, block_state, i, t): ++ # diff diff ++ if i == 0 and block_state.denoising_start is None: ++ block_state.latents = block_state.original_latents[:1] ++ else: ++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0) ++ # cast mask to the same type as latents etc ++ block_state.mask = block_state.mask.to(block_state.latents.dtype) ++ block_state.mask = block_state.mask.unsqueeze(1) # fit shape ++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) ++ # end diff diff + ++ # expand the latents if we are doing classifier free guidance + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state +``` + +That's all there is to it! Now your script should run as expected and get a result like this one. + +Here is the pipeline we created ( hint, `print(dd_blocks)`) +It is a simple sequential pipeline. + +``` +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + mask_processor (`VaeImageProcessor`) + unet (`UNet2DConditionModel`) + + Configs: + force_zeros_for_empty_prompt (default: True) + requires_aesthetics_score (default: False) + + Blocks: + [0] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [1] image_encoder (StableDiffusionXLVaeEncoderStep) + Description: Vae Encoder step that encode the input image into a latent representation + + [2] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [3] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [4] prepare_latents (SDXLDiffDiffPrepareLatentsStep) + Description: Step that prepares the latents for the differential diffusion generation process + + [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process + + [6] denoise (SDXLDiffDiffDenoiseLoop) + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes + + [7] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +Now if you run the example we prepared earlier, you should see an apple with its right half transformed into a green pear. + +![Image description](https://cdn-uploads.huggingface.co/production/uploads/624ef9ba9d608e459387b34e/4zqJOz-35Q0i6jyUW3liL.png) + + +## Adding IP-adapter + +We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](TODO) + +Let's create IP-adapter block: + +```python +from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep +ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +print(ip_adapter_block) +``` + +It has 4 components: `unet` and `guider` are already used in diff-diff, but it also has two new ones: `image_encoder` and `feature_extractor` + +```out + ip adapter block: StableDiffusionXLAutoIPAdapterStep( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'ip_adapter_image'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). + ==================================================================================================== + + + Description: Run IP Adapter step if `ip_adapter_image` is provided. + + + Components: + image_encoder (`CLIPVisionModelWithProjection`) + feature_extractor (`CLIPImageProcessor`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + + Blocks: + • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) + Description: IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details + +) +``` + +We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `blocks` attribute is a `InsertableOrderedDict`, so we're able to insert the it at specific position (index `0` here). + +```python +dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) +``` + +Take a look at the new diff-diff pipeline with ip-adapter! + +```python +print(dd_blocks) +``` + +The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor` + +```out +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'ip_adapter_image'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). + ==================================================================================================== + + + Description: + + + Components: + image_encoder (`CLIPVisionModelWithProjection`) + feature_extractor (`CLIPImageProcessor`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + mask_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) + requires_aesthetics_score (default: False) + + Blocks: + [0] ip_adapter (StableDiffusionXLAutoIPAdapterStep) + Description: Run IP Adapter step if `ip_adapter_image` is provided. + + [1] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [2] image_encoder (StableDiffusionXLVaeEncoderStep) + Description: Vae Encoder step that encode the input image into a latent representation + + [3] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [4] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [5] prepare_latents (SDXLDiffDiffPrepareLatentsStep) + Description: Step that prepares the latents for the differential diffusion generation process + + [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process + + [7] denoise (SDXLDiffDiffDenoiseLoop) + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes + + [8] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +Let's test it out. I used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output. + + +```python +ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) + +dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +dd_pipeline.load_components(torch_dtype=torch.float16) +dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") +dd_pipeline.loader.set_ip_adapter_scale(0.6) +dd_pipeline = dd_pipeline.to(device) + +ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg") +image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") + +prompt = "a green pear" +negative_prompt = "blurry" +generator = torch.Generator(device=device).manual_seed(42) + +image = dd_pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=25, + generator=generator, + ip_adapter_image=ip_adapter_image, + diffdiff_map=mask, + image=image, + output="images" +)[0] + +``` + +## Working with ControlNets + +What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are: + * A ControlNet input step that prepares the control condition + * Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet + +From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts. + +Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs. + +With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseLoop`: + +```python +class SDXLDiffDiffControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseLoop() +# print(controlnet_denoise) +``` + +We provide a auto controlnet input block that you can directly put into your workflow: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union. + + +```python +from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLControlNetAutoInput +control_input_block = StableDiffusionXLControlNetAutoInput() +print(control_input_block) +``` + +```out +StableDiffusionXLControlNetAutoInput( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'control_image', 'control_mode'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_image')`). + ==================================================================================================== + + + Description: Controlnet Input step that prepare the controlnet input. + This is an auto pipeline block that works for both controlnet and controlnet_union. + - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. + - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. + + + Components: + controlnet (`ControlNetUnionModel`) + control_image_processor (`VaeImageProcessor`) + + Blocks: + • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep) + Description: step that prepares inputs for the ControlNetUnion model + + • controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep) + Description: step that prepare inputs for controlnet + +) +``` + +Let's assemble the blocks and run an example using controlnet + differential diffusion. I used a canny of a tomato as `control_image`, so you can see in the output, the right half that transformed into a pear had a tomato-like shape. + +```python +dd_blocks.blocks.insert("controlnet_input", control_input_block, 7) +dd_blocks.blocks["denoise"] = controlnet_denoise_block + +dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +dd_pipeline.load_components(torch_dtype=torch.float16) +dd_pipeline = dd_pipeline.to(device) + +control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") +image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") + +prompt = "a green pear" +negative_prompt = "blurry" +generator = torch.Generator(device=device).manual_seed(42) + +image = dd_pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=25, + generator=generator, + control_image=control_image, + controlnet_conditioning_scale=0.5, + diffdiff_map=mask, + image=image, + output="images" +)[0] +``` + +Optionally, We can combine `SDXLDiffDiffControlNetDenoiseLoop` and `SDXLDiffDiffDenoiseLoop` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet. + + +```python +class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [SDXLDiffDiffControlNetDenoiseLoop, SDXLDiffDiffDenoiseLoop] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] +``` + +`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step. + +We won't go into too much detail about `AutoPipelineBlocks` in this section, but you can read more about it [here](TODO). Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected. + +now you can create the differential diffusion preset that works with ip-adapter & controlnet. + +```python +DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() +DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep +DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] +DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep +DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) +DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7) + +print(DIFFDIFF_AUTO_BLOCKS) +``` + +to use + +```python +dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS) +dd_pipeline = dd_auto_blocks.init_pipeline(...) +``` +## Creating a Modular Repo + +You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff + +[YiYi TODO: add details tutorial on how to create the modular repo, building upon this https://github.com/huggingface/diffusers/pull/11462] + +With a modular repo, it is very easy for the community to use the workflow you just created! + +```python + +from diffusers.modular_pipelines import ModularPipeline, ComponentsManager +import torch +from diffusers.utils import load_image + +repo_id = "YiYiXu/modular-diffdiff" + +components = ComponentsManager() + +diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, component_manager=components, collection="diffdiff") +diffdiff_pipeline.loader.load(torch_dtype=torch.float16) +components.enable_auto_cpu_offload() +``` + +see more usage example on model card + +## deploy a mellon node + +YIYI TODO: an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon From 42c06e90f42ba260150244f736c2b335fc6fcec9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Jun 2025 17:55:32 +0200 Subject: [PATCH 078/170] update doc --- .../en/modular_diffusers/developer_guide.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md index 175383f0d0ca..8ab71a2c7bbd 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -1,6 +1,22 @@ + + # Developer Guide: Building with Modular Diffusers -To implement new pipelines in the modular framework, you can use this process: +[[open-in-colab]] + +In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node. + +We'll also demonstrate the 3-step framework process we use for implementing new basic pipelines in the modular system. #### 1. **Start with an existing pipeline as a base** - Identify which existing pipeline is most similar to your target From 1ae591e81705526704904909e2119b358b743fc3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Jun 2025 18:08:55 +0200 Subject: [PATCH 079/170] update code format --- .../en/modular_diffusers/developer_guide.md | 327 +++++++++--------- 1 file changed, 162 insertions(+), 165 deletions(-) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md index 8ab71a2c7bbd..7dc0a682f543 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -38,24 +38,24 @@ Let's see how this works with the Differential Diffusion example. Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them. -```python -IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseLoop), - ("decode", StableDiffusionXLDecodeStep) -]) +```py +>>> IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ +... ("text_encoder", StableDiffusionXLTextEncoderStep), +... ("image_encoder", StableDiffusionXLVaeEncoderStep), +... ("input", StableDiffusionXLInputStep), +... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), +... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +... ("denoise", StableDiffusionXLDenoiseLoop), +... ("decode", StableDiffusionXLDecodeStep) +... ]) ``` Note that "denoise" (`StableDiffusionXLDenoiseLoop`) is a loop that contains 3 loop blocks (more on SequentialLoopBlocks [here](https://colab.research.google.com/drive/1iVRjy_tOfmmm4gd0iVe0_Rl3c6cBzVqi?usp=sharing)) -```python -denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() -print(denoise_blocks) +```py +>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() +>>> print(denoise_blocks) ``` ```out @@ -103,66 +103,65 @@ Differential diffusion shares exact same pipeline structure as img2img. Here is ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones: -```python -# Copy existing blocks as placeholders -class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): - """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later""" - # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep - -class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): - """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later""" - # ... same implementation as StableDiffusionXLLoopBeforeDenoiser +```py +>>> # Copy existing blocks as placeholders +>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): +... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later""" +... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep +... +>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): +... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later""" +... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser ``` `SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseLoop`. -```python -class SDXLDiffDiffDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] +```py +>>> class SDXLDiffDiffDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] +... block_names = ["before_denoiser", "denoiser", "after_denoiser"] ``` Now we can put together our differential diffusion pipeline. -```python -DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() -DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] -DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep -DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseLoop - -dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS) -print(dd_blocks) -# At this point, the pipeline works exactly like img2img since our blocks are just copies +```py +>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() +>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] +>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep +>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseLoop +>>> +>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS) +>>> print(dd_blocks) +>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies ``` ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple exapmple so we can run the pipeline as we build it. diff-diff use same components as SDXL so we can fetch the models from a regular SDXL repo. -```python -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_componenets(torch_dtype=torch.float16) -dd_pipeline.to("cuda") +```py +>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +>>> dd_pipeline.load_componenets(torch_dtype=torch.float16) +>>> dd_pipeline.to("cuda") ``` We will use this example script: -```python - -image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") -mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") - -prompt = "a green pear" -negative_prompt = "blurry" - -image = dd_pipeline.run( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - diffdiff_map=mask, - image=image, - output="images" -)[0] - -image.save("diffdiff_out.png") +```py +>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +>>> +>>> prompt = "a green pear" +>>> negative_prompt = "blurry" +>>> +>>> image = dd_pipeline.run( +... prompt=prompt, +... negative_prompt=negative_prompt, +... num_inference_steps=25, +... diffdiff_map=mask, +... image=image, +... output="images" +... )[0] +>>> +>>> image.save("diffdiff_out.png") ``` If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`. @@ -330,7 +329,7 @@ That's all there is to it! Now your script should run as expected and get a resu Here is the pipeline we created ( hint, `print(dd_blocks)`) It is a simple sequential pipeline. -``` +```out SequentialPipelineBlocks( Class: ModularPipelineBlocks @@ -398,10 +397,10 @@ We provide an auto IP-adapter block that you can plug-and-play into your modular Let's create IP-adapter block: -```python -from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep -ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() -print(ip_adapter_block) +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep +>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +>>> print(ip_adapter_block) ``` It has 4 components: `unet` and `guider` are already used in diff-diff, but it also has two new ones: `image_encoder` and `feature_extractor` @@ -435,14 +434,14 @@ It has 4 components: `unet` and `guider` are already used in diff-diff, but it a We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `blocks` attribute is a `InsertableOrderedDict`, so we're able to insert the it at specific position (index `0` here). -```python -dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) +```py +>>> dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) ``` Take a look at the new diff-diff pipeline with ip-adapter! -```python -print(dd_blocks) +```py +>>> print(dd_blocks) ``` The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor` @@ -519,35 +518,34 @@ SequentialPipelineBlocks( Let's test it out. I used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output. -```python -ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() -dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) - -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_components(torch_dtype=torch.float16) -dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") -dd_pipeline.loader.set_ip_adapter_scale(0.6) -dd_pipeline = dd_pipeline.to(device) - -ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg") -image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") -mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") - -prompt = "a green pear" -negative_prompt = "blurry" -generator = torch.Generator(device=device).manual_seed(42) - -image = dd_pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - generator=generator, - ip_adapter_image=ip_adapter_image, - diffdiff_map=mask, - image=image, - output="images" -)[0] - +```py +>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +>>> dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) +>>> +>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +>>> dd_pipeline.load_components(torch_dtype=torch.float16) +>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") +>>> dd_pipeline.loader.set_ip_adapter_scale(0.6) +>>> dd_pipeline = dd_pipeline.to(device) +>>> +>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg") +>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +>>> +>>> prompt = "a green pear" +>>> negative_prompt = "blurry" +>>> generator = torch.Generator(device=device).manual_seed(42) +>>> +>>> image = dd_pipeline( +... prompt=prompt, +... negative_prompt=negative_prompt, +... num_inference_steps=25, +... generator=generator, +... ip_adapter_image=ip_adapter_image, +... diffdiff_map=mask, +... image=image, +... output="images" +... )[0] ``` ## Working with ControlNets @@ -562,22 +560,22 @@ Intuitively, these two techniques are orthogonal and should combine naturally: d With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseLoop`: -```python -class SDXLDiffDiffControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseLoop() -# print(controlnet_denoise) +```py +>>> class SDXLDiffDiffControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] +... block_names = ["before_denoiser", "denoiser", "after_denoiser"] +>>> +>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseLoop() +>>> # print(controlnet_denoise) ``` We provide a auto controlnet input block that you can directly put into your workflow: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union. -```python -from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLControlNetAutoInput -control_input_block = StableDiffusionXLControlNetAutoInput() -print(control_input_block) +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLControlNetAutoInput +>>> control_input_block = StableDiffusionXLControlNetAutoInput() +>>> print(control_input_block) ``` ```out @@ -613,43 +611,43 @@ StableDiffusionXLControlNetAutoInput( Let's assemble the blocks and run an example using controlnet + differential diffusion. I used a canny of a tomato as `control_image`, so you can see in the output, the right half that transformed into a pear had a tomato-like shape. -```python -dd_blocks.blocks.insert("controlnet_input", control_input_block, 7) -dd_blocks.blocks["denoise"] = controlnet_denoise_block - -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_components(torch_dtype=torch.float16) -dd_pipeline = dd_pipeline.to(device) - -control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") -image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") -mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") - -prompt = "a green pear" -negative_prompt = "blurry" -generator = torch.Generator(device=device).manual_seed(42) - -image = dd_pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - generator=generator, - control_image=control_image, - controlnet_conditioning_scale=0.5, - diffdiff_map=mask, - image=image, - output="images" -)[0] +```py +>>> dd_blocks.blocks.insert("controlnet_input", control_input_block, 7) +>>> dd_blocks.blocks["denoise"] = controlnet_denoise_block +>>> +>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +>>> dd_pipeline.load_components(torch_dtype=torch.float16) +>>> dd_pipeline = dd_pipeline.to(device) +>>> +>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") +>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +>>> +>>> prompt = "a green pear" +>>> negative_prompt = "blurry" +>>> generator = torch.Generator(device=device).manual_seed(42) +>>> +>>> image = dd_pipeline( +... prompt=prompt, +... negative_prompt=negative_prompt, +... num_inference_steps=25, +... generator=generator, +... control_image=control_image, +... controlnet_conditioning_scale=0.5, +... diffdiff_map=mask, +... image=image, +... output="images" +... )[0] ``` Optionally, We can combine `SDXLDiffDiffControlNetDenoiseLoop` and `SDXLDiffDiffDenoiseLoop` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet. -```python -class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [SDXLDiffDiffControlNetDenoiseLoop, SDXLDiffDiffDenoiseLoop] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] +```py +>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): +... block_classes = [SDXLDiffDiffControlNetDenoiseLoop, SDXLDiffDiffDenoiseLoop] +... block_names = ["controlnet_denoise", "denoise"] +... block_trigger_inputs = ["controlnet_cond", None] ``` `SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step. @@ -658,22 +656,22 @@ We won't go into too much detail about `AutoPipelineBlocks` in this section, but now you can create the differential diffusion preset that works with ip-adapter & controlnet. -```python -DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() -DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep -DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] -DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep -DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) -DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7) - -print(DIFFDIFF_AUTO_BLOCKS) +```py +>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() +>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep +>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] +>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep +>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) +>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7) +>>> +>>> print(DIFFDIFF_AUTO_BLOCKS) ``` to use -```python -dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS) -dd_pipeline = dd_auto_blocks.init_pipeline(...) +```py +>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS) +>>> dd_pipeline = dd_auto_blocks.init_pipeline(...) ``` ## Creating a Modular Repo @@ -683,23 +681,22 @@ You can easily share your differential diffusion workflow on the hub, by creatin With a modular repo, it is very easy for the community to use the workflow you just created! -```python - -from diffusers.modular_pipelines import ModularPipeline, ComponentsManager -import torch -from diffusers.utils import load_image - -repo_id = "YiYiXu/modular-diffdiff" - -components = ComponentsManager() - -diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, component_manager=components, collection="diffdiff") -diffdiff_pipeline.loader.load(torch_dtype=torch.float16) -components.enable_auto_cpu_offload() +```py +>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager +>>> import torch +>>> from diffusers.utils import load_image +>>> +>>> repo_id = "YiYiXu/modular-diffdiff" +>>> +>>> components = ComponentsManager() +>>> +>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, component_manager=components, collection="diffdiff") +>>> diffdiff_pipeline.loader.load(torch_dtype=torch.float16) +>>> components.enable_auto_cpu_offload() ``` see more usage example on model card ## deploy a mellon node -YIYI TODO: an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon +[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon] From bb4044362ebdfbe460826b55268443cd56c15d1e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Jun 2025 18:37:28 +0200 Subject: [PATCH 080/170] up --- .../en/modular_diffusers/developer_guide.md | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md index 7dc0a682f543..d4a66b067398 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -169,16 +169,26 @@ and you would get the same result as the original img2img pipeline. Let's modify the pipeline so that we can get expected result with this example script. -We'll start with the `prepare_latents` step, as it is the first step that gets called right after the `input` step. The main changes are: -- new input `diffdiff_map`: It will become a new input to the pipeline after we built it. -- `num_inference_steps` and `timestesp` as intermediates inputs: Both variables are created in `set_timesteps` block, we need to list them as intermediates inputs so that we can now use them in `__call__`. -- A new component `mask_processor`: A default one will be created when we build the pipeline, but user can update it. -- Inside `__call__`, we created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`. We also need to list them as intermediates outputs so the we can use them in the `denoise` step later. - -I have two tips I want to share for this process: -1. use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline. +We'll start with the `prepare_latents` step, as it is the first step that gets called right after the `input` step. Let's first apply changes in inputs/outputs/components. The main changes are: +- new input `diffdiff_map` +- new intermediates inputs `num_inference_steps` and `timestesp`. Both variables are already created in `set_timesteps` block, we can now need to use them in `prepare_latents` step. +- A new component `mask_processor` to process the `diffdiff_map` + + + +💡 use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline. + e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input. -2. insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. + + + +Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`. We also need to list them as intermediates outputs so the we can use them in the `denoise` step later. + + + +💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does. + + This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up with : ```diff @@ -265,7 +275,7 @@ This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up wi self.add_block_state(state, block_state) ``` -This is the modified `before_denoiser` step, we use diff-diff map to freeze certain regions in the latents before each denoising step. +Now let's modify `before_denoiser` step, we use diff-diff map to freeze certain regions in the latents before each denoising step. ```diff class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): @@ -324,10 +334,15 @@ class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): return components, block_state ``` -That's all there is to it! Now your script should run as expected and get a result like this one. +That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks. + + + + +💡 You can inspect the pipeline you built with `print()` + + -Here is the pipeline we created ( hint, `print(dd_blocks)`) -It is a simple sequential pipeline. ```out SequentialPipelineBlocks( @@ -515,7 +530,7 @@ SequentialPipelineBlocks( ) ``` -Let's test it out. I used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output. +Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output. ```py @@ -551,8 +566,8 @@ Let's test it out. I used an orange image to condition the generation via ip-add ## Working with ControlNets What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are: - * A ControlNet input step that prepares the control condition - * Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet +1. A ControlNet input step that prepares the control condition +2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts. @@ -569,7 +584,7 @@ With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseLoop`: >>> # print(controlnet_denoise) ``` -We provide a auto controlnet input block that you can directly put into your workflow: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union. +We provide a auto controlnet input block that you can directly put into your workflow to proceess the `control_image`: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union. ```py @@ -609,7 +624,7 @@ StableDiffusionXLControlNetAutoInput( ) ``` -Let's assemble the blocks and run an example using controlnet + differential diffusion. I used a canny of a tomato as `control_image`, so you can see in the output, the right half that transformed into a pear had a tomato-like shape. +Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape. ```py >>> dd_blocks.blocks.insert("controlnet_input", control_input_block, 7) @@ -652,9 +667,13 @@ Optionally, We can combine `SDXLDiffDiffControlNetDenoiseLoop` and `SDXLDiffDiff `SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step. -We won't go into too much detail about `AutoPipelineBlocks` in this section, but you can read more about it [here](TODO). Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected. + + + Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected. We won't go into too much detail about `AutoPipelineBlocks` in this section, but you can read more about it [here](TODO). + + -now you can create the differential diffusion preset that works with ip-adapter & controlnet. +Now you can create the differential diffusion preset that works with ip-adapter & controlnet. ```py >>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() From 7c78fb1aadc318fdabee6a829798f24c25860617 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 24 Jun 2025 08:16:34 +0200 Subject: [PATCH 081/170] add a overview doc page --- docs/source/en/modular_diffusers/overview.md | 206 +++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 docs/source/en/modular_diffusers/overview.md diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md new file mode 100644 index 000000000000..8321430820c5 --- /dev/null +++ b/docs/source/en/modular_diffusers/overview.md @@ -0,0 +1,206 @@ + + +# Overview + +The Modular Diffusers Framework consist of three main components + +## ModularPipelineBlocks + +Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: +- [`PipelineBlock`](TODO) +- [`SequentialPipelineBlocks`](TODO) +- [`LoopSequentialPipelineBlocks`](TODO) +- [`AutoPipelineBlocks`](TODO) + + +Each block defines: + +**Specifications:** +- Inputs: User-provided parameters that the block expects +- Intermediate inputs: Variables from other blocks that this block needs +- Intermediate outputs: Variables this block produces for other blocks to use +- Components: Models and processors the block requires (e.g., UNet, VAE, scheduler) + +**Computation:** +- `__call__` method: Defines the actual computational steps within the block + +Pipeline blocks are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but are not runnable until converted into a `ModularPipeline` object. + +All blocks interact with a global `PipelineState` object that maintains the pipeline's state throughout execution. + +### Load/save a custom `ModularPipelineBlocks` + +You can load a custom pipeline block from a hub repository directly + +```py +from diffusers import ModularPipelineBlocks +diffdiff_block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) +``` + +to save, and publish to a hub repository + +```py +diffdiff_block.save(repo_id) +``` + +## PipelineState & BlockState + +`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates PipelineState with any changes. + + + +You typically don't need to manually create or manage these state objects. The `ModularPipeline` automatically creates and manages them for you. However, understanding their roles is important for developing custom pipeline blocks. + + + +## ModularPipeline + +`ModularPipeline` is the main interface to create and execute pipelines in the Modular Diffusers system. + +### Create a `ModularPipeline` + +Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. + +```py +>>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path) +``` + +`ModularPipeline` only works with modular repositories, so make sure `pretrained_model_name_or_path` points to a modular repo (you can see an example [here](https://huggingface.co/YiYiXu/modular-diffdiff)). + +The main differences from standard diffusers repositories are: + +1. `modular_model_index.json` vs `model_index.json` + +In standard `model_index.json`, each component entry is a `(library, class)` tuple: + +```py +"text_encoder": [ + "transformers", + "CLIPTextModel" +], +``` + +In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` + +- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (can be `None` if not loaded) +- **`loading_specs`**: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` + +```py +"text_encoder": [ + null, # library (same as model_index.json) + null, # class (same as model_index.json) + { # loading specs map (unique to modular_model_index.json) + "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ # (library, class) for the expected component class + "transformers", + "CLIPTextModel" + ], + "variant": null + } +], +``` + +2. Cross-Repository Component Loading + +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. In our example above, the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. + + + + +💡 We recommend using `ModularPipeline` with Component Manager by passing a `components_manager`: + +```py +>>> components = ComponentsManager() +>>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path, components_manager=components) +``` + +This helps you to: +1. Detect and manage duplicated models (warns when trying to register an existing model) +2. Easily reuse components across different pipelines +3. Apply offloading strategies across multiple pipelines + +You can read more about Components Manager [here](TODO) + + + + +Unlike `DiffusionPipeline`, you need to explicitly load model components using `load_components`: + +```py +>>> pipeline.load_components(torch_dtype=torch.float16) +>>> pipeline.to(device) +``` + +You can partially load specific components using the `component_names` argument, for example to only load unet and vae: + +```py +>>> pipeline.load_components(component_names=["unet", "vae"]) +``` + + + +💡 You can inspect the pipeline's `config` attribute (which contains the same structure as `modular_model_index.json` we just walked through) to check the "loading status" of the pipeline, e.g. what components this pipeline expects to load and their loading specs, what components are already loaded and their actual class & loading specs etc. + + + +### Execute a `ModularPipeline` + +The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: + +```py +>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0] +``` + +There are a few key differences though: +1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments +2. If you do not specify the `output` argument, it returns the `PipelineState` object +3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents + +Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input. + +### Load a `ModularPipeline` from hub + +You can directly load a `ModularPipeline` from a HuggingFace Hub repository, as long as it's a modular repo + +```py +pipeine = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) +``` + +Loading custom code is also supported, just pass a `trust_remote_code=True` argument: + +```py +diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, ...) +``` + +The ModularPipeine created with `from_pretrained` method also would not load any components and you would have to call `load_components` to explicitly load components you need. + + +### Save a `ModularPipeline` + +to save a `ModularPipeline` and publish it to hub + +```py +pipeline.save_pretrained("YiYiXu/modular-loader-t2i", push_to_hub=True) +``` + + + +We do not automatically save custom code and share it on hub for you, please read more about how to share your custom pipeline on hub [here](TODO: ModularPipeline/CustomCode) + + + + + + From 48e4ff5c05568130ad419c65fa7ce565283f665a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 24 Jun 2025 10:17:35 +0200 Subject: [PATCH 082/170] update overview --- docs/source/en/_toctree.yml | 2 + docs/source/en/modular_diffusers/overview.md | 115 ++++++++++--------- 2 files changed, 63 insertions(+), 54 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d8fc32f38093..858309391fce 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -89,6 +89,8 @@ - sections: - local: modular_diffusers/developer_guide title: Developer Guide + - local: modular_diffusers/overview + title: Overview - sections: - local: using-diffusers/cogvideox title: CogVideoX diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 8321430820c5..ecb7d4cf1fea 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # Overview -The Modular Diffusers Framework consist of three main components +The Modular Diffusers Framework consists of three main components: ## ModularPipelineBlocks @@ -23,35 +23,37 @@ Pipeline blocks are the fundamental building blocks of the Modular Diffusers sys - [`AutoPipelineBlocks`](TODO) -Each block defines: - -**Specifications:** -- Inputs: User-provided parameters that the block expects -- Intermediate inputs: Variables from other blocks that this block needs -- Intermediate outputs: Variables this block produces for other blocks to use -- Components: Models and processors the block requires (e.g., UNet, VAE, scheduler) - -**Computation:** -- `__call__` method: Defines the actual computational steps within the block - -Pipeline blocks are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but are not runnable until converted into a `ModularPipeline` object. - -All blocks interact with a global `PipelineState` object that maintains the pipeline's state throughout execution. - -### Load/save a custom `ModularPipelineBlocks` - -You can load a custom pipeline block from a hub repository directly - +To use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers ```py -from diffusers import ModularPipelineBlocks -diffdiff_block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) +>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep +>>> text_encoder_block = StableDiffusionXLTextEncoderStep() ``` -to save, and publish to a hub repository +Each [`ModularPipelineBlocks`] defines its requirement for components, configs, inputs, intermediate inputs, and outputs. You'll see that this text encoder block uses text_encoders, tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return a list of conditional text embeddings. -```py -diffdiff_block.save(repo_id) ``` +>>> text_encoder_block +StableDiffusionXLTextEncoderStep( + Class: PipelineBlock + Description: Text Encoder step that generate text_embeddings to guide the image generation + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + Configs: + force_zeros_for_empty_prompt (default: True) + Inputs: + prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None + Intermediates: + - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +) +``` + +Pipeline blocks are essentially **"definitions"** - they define the specifications and computational steps for a pipeline. However, they do not contain any model states, and are not runnable until converted into a `ModularPipeline` object. + +Read more about how to write your own `ModularPipelineBlocks` [here](TODO) ## PipelineState & BlockState @@ -67,15 +69,9 @@ You typically don't need to manually create or manage these state objects. The ` `ModularPipeline` is the main interface to create and execute pipelines in the Modular Diffusers system. -### Create a `ModularPipeline` +### Modular Repo -Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. - -```py ->>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path) -``` - -`ModularPipeline` only works with modular repositories, so make sure `pretrained_model_name_or_path` points to a modular repo (you can see an example [here](https://huggingface.co/YiYiXu/modular-diffdiff)). +`ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). The main differences from standard diffusers repositories are: @@ -93,7 +89,7 @@ In standard `model_index.json`, each component entry is a `(library, class)` tup In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` - `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (can be `None` if not loaded) -- **`loading_specs`**: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` +- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` ```py "text_encoder": [ @@ -114,7 +110,16 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra 2. Cross-Repository Component Loading -Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. In our example above, the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. + + +### Create a `ModularPipeline` from `ModularPipelineBlocks` + +Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. + +```py +>>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path) +``` @@ -135,7 +140,6 @@ You can read more about Components Manager [here](TODO) - Unlike `DiffusionPipeline`, you need to explicitly load model components using `load_components`: ```py @@ -155,41 +159,42 @@ You can partially load specific components using the `component_names` argument, -### Execute a `ModularPipeline` +### Load a `ModularPipeline` from hub -The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: +You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: ```py ->>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0] +pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) ``` -There are a few key differences though: -1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments -2. If you do not specify the `output` argument, it returns the `PipelineState` object -3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents +Loading custom code is also supported: -Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input. +```py +diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, ...) +``` -### Load a `ModularPipeline` from hub +Similar to `init_pipeline` method, the modular pipeline will not load any components automatically, so you will have to call `load_components` to explicitly load the components you need. -You can directly load a `ModularPipeline` from a HuggingFace Hub repository, as long as it's a modular repo -```py -pipeine = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) -``` +### Execute a `ModularPipeline` -Loading custom code is also supported, just pass a `trust_remote_code=True` argument: +The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: ```py -diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, ...) +>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0] ``` -The ModularPipeine created with `from_pretrained` method also would not load any components and you would have to call `load_components` to explicitly load components you need. +There are a few key differences though: +1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments +2. If you do not specify the `output` argument, it returns the `PipelineState` object +3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents + +Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input. ### Save a `ModularPipeline` -to save a `ModularPipeline` and publish it to hub +To save a `ModularPipeline` and publish it to hub: ```py pipeline.save_pretrained("YiYiXu/modular-loader-t2i", push_to_hub=True) @@ -197,7 +202,9 @@ pipeline.save_pretrained("YiYiXu/modular-loader-t2i", push_to_hub=True) -We do not automatically save custom code and share it on hub for you, please read more about how to share your custom pipeline on hub [here](TODO: ModularPipeline/CustomCode) +We do not automatically save custom code and share it on hub for you. Please read more about how to share your custom pipeline on hub [here](TODO: ModularPipeline/CustomCode) + + From e49413d87d09ac0bc266fd4788df41e22a64fd7d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 08:52:15 +0200 Subject: [PATCH 083/170] update doc --- docs/source/en/_toctree.yml | 4 +- docs/source/en/modular_diffusers/overview.md | 213 ----- docs/source/en/modular_diffusers/quicktour.md | 871 ++++++++++++++++++ 3 files changed, 873 insertions(+), 215 deletions(-) delete mode 100644 docs/source/en/modular_diffusers/overview.md create mode 100644 docs/source/en/modular_diffusers/quicktour.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 858309391fce..e940e60b1151 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -89,8 +89,8 @@ - sections: - local: modular_diffusers/developer_guide title: Developer Guide - - local: modular_diffusers/overview - title: Overview + - local: modular_diffusers/quicktour + title: Quicktour - sections: - local: using-diffusers/cogvideox title: CogVideoX diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md deleted file mode 100644 index ecb7d4cf1fea..000000000000 --- a/docs/source/en/modular_diffusers/overview.md +++ /dev/null @@ -1,213 +0,0 @@ - - -# Overview - -The Modular Diffusers Framework consists of three main components: - -## ModularPipelineBlocks - -Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: -- [`PipelineBlock`](TODO) -- [`SequentialPipelineBlocks`](TODO) -- [`LoopSequentialPipelineBlocks`](TODO) -- [`AutoPipelineBlocks`](TODO) - - -To use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers -```py ->>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep ->>> text_encoder_block = StableDiffusionXLTextEncoderStep() -``` - -Each [`ModularPipelineBlocks`] defines its requirement for components, configs, inputs, intermediate inputs, and outputs. You'll see that this text encoder block uses text_encoders, tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return a list of conditional text embeddings. - -``` ->>> text_encoder_block -StableDiffusionXLTextEncoderStep( - Class: PipelineBlock - Description: Text Encoder step that generate text_embeddings to guide the image generation - Components: - text_encoder (`CLIPTextModel`) - text_encoder_2 (`CLIPTextModelWithProjection`) - tokenizer (`CLIPTokenizer`) - tokenizer_2 (`CLIPTokenizer`) - guider (`ClassifierFreeGuidance`) - Configs: - force_zeros_for_empty_prompt (default: True) - Inputs: - prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None - Intermediates: - - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds -) -``` - -Pipeline blocks are essentially **"definitions"** - they define the specifications and computational steps for a pipeline. However, they do not contain any model states, and are not runnable until converted into a `ModularPipeline` object. - -Read more about how to write your own `ModularPipelineBlocks` [here](TODO) - -## PipelineState & BlockState - -`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates PipelineState with any changes. - - - -You typically don't need to manually create or manage these state objects. The `ModularPipeline` automatically creates and manages them for you. However, understanding their roles is important for developing custom pipeline blocks. - - - -## ModularPipeline - -`ModularPipeline` is the main interface to create and execute pipelines in the Modular Diffusers system. - -### Modular Repo - -`ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). - -The main differences from standard diffusers repositories are: - -1. `modular_model_index.json` vs `model_index.json` - -In standard `model_index.json`, each component entry is a `(library, class)` tuple: - -```py -"text_encoder": [ - "transformers", - "CLIPTextModel" -], -``` - -In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` - -- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (can be `None` if not loaded) -- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` - -```py -"text_encoder": [ - null, # library (same as model_index.json) - null, # class (same as model_index.json) - { # loading specs map (unique to modular_model_index.json) - "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo - "revision": null, - "subfolder": "text_encoder", - "type_hint": [ # (library, class) for the expected component class - "transformers", - "CLIPTextModel" - ], - "variant": null - } -], -``` - -2. Cross-Repository Component Loading - -Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. - - -### Create a `ModularPipeline` from `ModularPipelineBlocks` - -Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. - -```py ->>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path) -``` - - - - -💡 We recommend using `ModularPipeline` with Component Manager by passing a `components_manager`: - -```py ->>> components = ComponentsManager() ->>> pipeline = blocks.init_pipeline(pretrained_model_name_or_path, components_manager=components) -``` - -This helps you to: -1. Detect and manage duplicated models (warns when trying to register an existing model) -2. Easily reuse components across different pipelines -3. Apply offloading strategies across multiple pipelines - -You can read more about Components Manager [here](TODO) - - - -Unlike `DiffusionPipeline`, you need to explicitly load model components using `load_components`: - -```py ->>> pipeline.load_components(torch_dtype=torch.float16) ->>> pipeline.to(device) -``` - -You can partially load specific components using the `component_names` argument, for example to only load unet and vae: - -```py ->>> pipeline.load_components(component_names=["unet", "vae"]) -``` - - - -💡 You can inspect the pipeline's `config` attribute (which contains the same structure as `modular_model_index.json` we just walked through) to check the "loading status" of the pipeline, e.g. what components this pipeline expects to load and their loading specs, what components are already loaded and their actual class & loading specs etc. - - - -### Load a `ModularPipeline` from hub - -You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: - -```py -pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) -``` - -Loading custom code is also supported: - -```py -diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, ...) -``` - -Similar to `init_pipeline` method, the modular pipeline will not load any components automatically, so you will have to call `load_components` to explicitly load the components you need. - - -### Execute a `ModularPipeline` - -The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: - -```py ->>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0] -``` - -There are a few key differences though: -1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments -2. If you do not specify the `output` argument, it returns the `PipelineState` object -3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents - -Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input. - - -### Save a `ModularPipeline` - -To save a `ModularPipeline` and publish it to hub: - -```py -pipeline.save_pretrained("YiYiXu/modular-loader-t2i", push_to_hub=True) -``` - - - -We do not automatically save custom code and share it on hub for you. Please read more about how to share your custom pipeline on hub [here](TODO: ModularPipeline/CustomCode) - - - - - - - - diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/quicktour.md new file mode 100644 index 000000000000..e8ef62403fd1 --- /dev/null +++ b/docs/source/en/modular_diffusers/quicktour.md @@ -0,0 +1,871 @@ + + +# Getting Started with Modular Diffusers + +With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers let you: + +**Write Only What's New**: You won't need to rewrite the entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. + +**Assemble Like LEGO®**: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that that can be used more conveniently for multiple workflows. + +In this guide, we will focus on how to use pipeline like this we built with Modular diffusers 🧨! We will also go over the basics of pipeline blocks, how they work under the hood, and how to assemble SequentialPipelineBlocks and AutoPipelineBlocks in this [guide](TODO). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](developer_guide.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. + +Let's get started! The Modular Diffusers Framework consists of three main components: + +## ModularPipelineBlocks + +Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: +- [`PipelineBlock`](TODO) +- [`SequentialPipelineBlocks`](TODO) +- [`LoopSequentialPipelineBlocks`](TODO) +- [`AutoPipelineBlocks`](TODO) + + +To use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep +>>> text_encoder_block = StableDiffusionXLTextEncoderStep() +``` + +Each [`ModularPipelineBlocks`] defines its requirement for components, configs, inputs, intermediate inputs, and outputs. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings such as `prompt_embeds` and `negative_prompt_embeds`. + +``` +>>> text_encoder_block +StableDiffusionXLTextEncoderStep( + Class: PipelineBlock + Description: Text Encoder step that generate text_embeddings to guide the image generation + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + Configs: + force_zeros_for_empty_prompt (default: True) + Inputs: + prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None + Intermediates: + - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +) +``` + +More commonly, you can create a `SequentialPipelineBlocks` using a modular blocks preset officially supported in 🧨 Diffusers. + + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) +``` + +This creates a text-to-image pipeline. + +```py +>>> t2i_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + scheduler (`EulerDiscreteScheduler`) + unet (`UNet2DConditionModel`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) + + Blocks: + [0] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [1] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [2] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [3] prepare_latents (StableDiffusionXLPrepareLatentsStep) + Description: Prepare latents step that prepares the latents for the text-to-image generation process + + [4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the text-to-image generation process + + [5] denoise (StableDiffusionXLDenoiseLoop) + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + + + [6] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +The blocks preset we used (`TEXT2IMAGE_BLOCKS`) is just a dictionary that maps names to ModularPipelineBlocks classes + +```py +>>> TEXT2IMAGE_BLOCKS +InsertableOrderedDict([ + 0: ('text_encoder', ), + 1: ('input', ), + 2: ('set_timesteps', ), + 3: ('prepare_latents', ), + 4: ('prepare_add_cond', ), + 5: ('denoise', ), + 6: ('decode', ) +]) +``` + +When we create a `SequentialPipelineBlocks` from this preset, it instantiates each class into actual block objects. Its `blocks` attribute contains these instantiated objects: + +```py +>>> t2i_blocks.blocks +InsertableOrderedDict([ + 0: ('text_encoder', ), + 1: ('input', ), + 2: ('set_timesteps', ), + 3: ('prepare_latents', ), + 4: ('prepare_add_cond', ), + 5: ('denoise', ), + 6: ('decode', ) +]) +``` + +Note that both the preset and the `blocks` attribute are `InsertableOrderedDict` objects, which allows you to modify them in several ways: + +**Add a block/block_class at specific positions:** +```py +# Add to preset (class) +BLOCKS.insert("block_name", BlockClass, index) +# Add to blocks attribute (instance) +t2i_blocks.blocks.insert("block_name", block_instance, index) +``` + +**Remove blocks:** +```py +# remove a block class from preset +BLOCKS.pop("text_encoder") +# split out a block instance on its own +text_encoder_block = t2i_blocks.blocks.pop("text_encoder") +``` + +**Swap/replace blocks:** +```py +# Replace in preset (class) +BLOCKS["prepare_latents"] = CustomPrepareLatents +# Replace in blocks attribute (instance) +t2i_blocks.blocks["prepare_latents"] = CustomPrepareLatents() +``` + +This means you can mix-and-match blocks in very flexible ways. Let's see some real examples: + +**Example 1: Adding IP-Adapter to the preset** +Let's insert IP-Adapter at index 0 (before the text_encoder block) to create a text-to-image pipeline with IP-Adapter support: + +```py +from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep +CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy() +CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) +custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS) +``` + +**Example 2: Extracting a block from the pipeline** +You can extract a block instance from the pipeline to use it independently. A common pattern is to extract the text_encoder to process prompts once, then reuse the text embeddings to generate multiple images with different settings (schedulers, seeds, inference steps). + +```py +>>> text_encoder_blocks = t2i_blocks.blocks.pop("text_encoder") +>>> text_encoder_blocks +StableDiffusionXLTextEncoderStep( + Class: PipelineBlock + Description: Text Encoder step that generate text_embeddings to guide the image generation + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + Configs: + force_zeros_for_empty_prompt (default: True) + Inputs: + prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None + Intermediates: + - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +) +``` + +the pipeline now has fewer components and no longer has the `text_encoder` block: + +```py +>>> t2i_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + Components: + scheduler (`EulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + unet (`UNet2DConditionModel`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + + Blocks: + [0] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [1] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [2] prepare_latents (StableDiffusionXLPrepareLatentsStep) + Description: Prepare latents step that prepares the latents for the text-to-image generation process + + [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the text-to-image generation process + + [4] denoise (StableDiffusionXLDenoiseLoop) + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + + + [5] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +We will not go over how to write your own ModularPipelineBlocks but you can learn more about it [here](TODO). + +This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. + +## PipelineState & BlockState + +`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates PipelineState with any changes. + + + +You typically don't need to manually create or manage these state objects. The `ModularPipeline` automatically creates and manages them for you. However, understanding their roles is important for developing custom pipeline blocks. + + + +## ModularPipeline + +`ModularPipeline` is the main interface to create and execute pipelines in the Modular Diffusers system. + +### Modular Repo + +`ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). + +Instead of using a `model_index.json` to configure components loading in `DiffusionPipeline`. Modular repositories work with `modular_model_index.json`. Let's walk through the difference here. + +In standard `model_index.json`, each component entry is a `(library, class)` tuple: + +```py +"text_encoder": [ + "transformers", + "CLIPTextModel" +], +``` + +In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` + +- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (can be `None` if not loaded) +- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` + +```py +"text_encoder": [ + null, # library (same as model_index.json) + null, # class (same as model_index.json) + { # loading specs map (unique to modular_model_index.json) + "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ # (library, class) for the expected component class + "transformers", + "CLIPTextModel" + ], + "variant": null + } +], +``` + +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. + + +### Creating a `ModularPipeline` from `ModularPipelineBlocks` + +Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. + +Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`: + +```py +# We already have this from earlier +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +# Now convert it to a ModularPipeline +modular_repo_id = "YiYiXu/modular-loader-t2i" +t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) +``` + + + +💡 We recommend using `ModularPipeline` with Component Manager by passing a `components_manager`: + +```py +>>> components = ComponentsManager() +>>> pipeline = blocks.init_pipeline(modular_repo_id, components_manager=components) +``` + +This helps you to: +1. Detect and manage duplicated models (warns when trying to register an existing model) +2. Easily reuse components across different pipelines +3. Apply offloading strategies across multiple pipelines + +You can read more about Components Manager [here](TODO) + + + + +### Creating a `ModularPipeline` with `from_pretrained` + +You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: + +```py +# YiYi TODO: this is not yet supported actually 😢, need to add support +from diffusers import ModularPipeline +pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) +``` + +Loading custom code is also supported: + +```py +from diffusers import ModularPipeline +modular_repo_id = "YiYiXu/modular-diffdiff" +diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True) +``` + +### Loading components into a `ModularPipeline` + +Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_components`: + +```py +# This will load ALL the expected components into pipeline +t2i_pipeline.load_components(torch_dtype=torch.float16) +t2i_pipeline.to(device) +``` + +All expected components are now loaded into the pipeline. You can also partially load specific components using the `component_names` argument. For example, to only load unet and vae: + +```py +>>> t2i_pipeline.load_components(component_names=["unet", "vae"]) +``` + +You can inspect the pipeline's loading status through its `loader` attribute to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. It has the same structure as the `modular_model_index.json` we discussed earlier - each component entry contains the `(library, class, loading_specs)` format. You'll need to understand that structure to properly read the loading status below. + +Let's inspect the `t2i_pipeline`, you can see all the components expected to load are listed as entries in the loader. The `guider` and `image_processor` components were created using default config (their `library` and `class` field are populated, this means they are initialized, but `loading_spec["repo"]` is null). The `vae` and `unet` components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their `library`, `class` fields are `null`), but you can examine their loading specs to see where they would be loaded from when you call `load_components()`. + +```py +>>> t2i_pipeline.loader +StableDiffusionXLModularLoader { + "_class_name": "StableDiffusionXLModularLoader", + "_diffusers_version": "0.34.0.dev0", + "force_zeros_for_empty_prompt": true, + "guider": [ + "diffusers", + "ClassifierFreeGuidance", + { + "repo": null, + "revision": null, + "subfolder": null, + "type_hint": [ + "diffusers", + "ClassifierFreeGuidance" + ], + "variant": null + } + ], + "image_processor": [ + "diffusers", + "VaeImageProcessor", + { + "repo": null, + "revision": null, + "subfolder": null, + "type_hint": [ + "diffusers", + "VaeImageProcessor" + ], + "variant": null + } + ], + "scheduler": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "scheduler", + "type_hint": [ + "diffusers", + "EulerDiscreteScheduler" + ], + "variant": null + } + ], + "text_encoder": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ + "transformers", + "CLIPTextModel" + ], + "variant": null + } + ], + "text_encoder_2": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "text_encoder_2", + "type_hint": [ + "transformers", + "CLIPTextModelWithProjection" + ], + "variant": null + } + ], + "tokenizer": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "tokenizer", + "type_hint": [ + "transformers", + "CLIPTokenizer" + ], + "variant": null + } + ], + "tokenizer_2": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "tokenizer_2", + "type_hint": [ + "transformers", + "CLIPTokenizer" + ], + "variant": null + } + ], + "unet": [ + "diffusers", + "UNet2DConditionModel", + { + "repo": "RunDiffusion/Juggernaut-XL-v9", + "revision": null, + "subfolder": "unet", + "type_hint": [ + "diffusers", + "UNet2DConditionModel" + ], + "variant": "fp16" + } + ], + "vae": [ + "diffusers", + "AutoencoderKL", + { + "repo": "madebyollin/sdxl-vae-fp16-fix", + "revision": null, + "subfolder": null, + "type_hint": [ + "diffusers", + "AutoencoderKL" + ], + "variant": null + } + ] +} +``` +### Updating components in a `ModularPipeline` + +Similar to `DiffusionPipeline`, You could load an components separately to replace the default one in the pipeline. But in Modular Diffusers system, you need to use `ComponentSpec` to load/create them. + +`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`. In Modular Diffusers, all pretrained models should be loaded using `ComponentSpec` objects. + +So instead of + +```py +from diffusers import UNet2DConditionModel +import torch +unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16) +``` +You should do + +```py +from diffusers import ComponentSpec, UNet2DConditionModel +unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16") +unet2 = unet_spec.load(torch_dtype=torch.float16) +``` + +The key difference is that the second unet (the one we load with `ComponentSpec`) retains its loading specs, so you can extract and recreate it: + +```py +# to extract spec, you can do spec.load() to recreate it +>>> spec = ComponentSpec.from_component("unet", unet2) +>>> spec +ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained') +``` + +To replace the unet in the pipeline + +``` +t2i_pipeline.update_components(unet=unet2) +``` + +Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0". This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet. + +``` +>>> t2i_pipeline.loader +StableDiffusionXLModularLoader { + ... + "unet": [ + "diffusers", + "UNet2DConditionModel", + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "unet", + "type_hint": [ + "diffusers", + "UNet2DConditionModel" + ], + "variant": "fp16" + } + ], + ... +} +``` + + +### Run a `ModularPipeline` + +The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: + +```py +>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0] +``` + +There are a few key differences though: +1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments +2. If you do not specify the `output` argument, it returns the `PipelineState` object +3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents + +Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input. + + + +You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list. + +```py +t2i_pipeline.doc +``` + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS + +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i" +t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) + +t2i_pipeline.load_components(torch_dtype=torch.float16) +t2i_pipeline.to("cuda") + +image = t2i_pipeline(prompt="a cat", output="images")[0] +image.save("modular_t2i_out.png") +``` + + +## An slightly advanced Workflow + +We've learned the basic components of the Modular Diffusers System. Now let's tie everything together with more practical example that demonstrates the true power of Modular Diffusers: working between with multiple pipelines that can share components. + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks, ComponentsManager +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS + +# create t2i blocks and then pop out the text_encoder step and decoder step so that we can use them in standalone manner +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS.copy()) +text_blocks = t2i_blocks.blocks.pop("text_encoder") +decoder_blocks = t2i_blocks.blocks.pop("decode") + +# Create a refiner blocks +# - removing image_encoder a since we'll use latents from t2i +# - removing decode since we already created a seperate decoder_block +i2i_blocks_dict = IMAGE2IMAGE_BLOCKS.copy() +i2i_blocks_dict.pop("image_encoder") +i2i_blocks_dict.pop("decode") +refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(i2i_blocks_dict) + +# Set up component manager and turn on the offloading +components = ComponentsManager() +components.enable_auto_cpu_offload(device="cuda") + +# convert all blocks into runnable pipelines: text_node, decoder_node, t2i_pipe, refiner_pipe +t2i_repo = "YiYiXu/modular-loader-t2i" +refiner_repo = "YiYiXu/modular_refiner" +dtype = torch.float16 + +text_node = text_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") +text_node.load_components(torch_dtype=dtype) + +decoder_node = decoder_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") +decoder_node.load_components(torch_dtype=dtype) + +t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") +t2i_pipe.load_components(torch_dtype=dtype) + +# for refiner pipeline, only unet is unique so we only load unet here, and we will reuse other components +refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, component_manager=components, collection="refiner") +refiner_pipe.load_components(component_names="unet", torch_dtype=dtype) +``` + +let's inspect components manager here, you can see that 5 models are automatically registered: two text encoders, two UNets, and one VAE. The models are organized by collection - 4 models under "t2i" and one UNet under "refiner". This happens because we passed a `collection` parameter when initializing each pipeline. For example, when we created the refiner pipeline, we did `refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, component_manager=components, collection="refiner")`. All models loaded by `refiner_pipe.load_components(...)` are automatically placed under the "refiner" collection. + +Notice that all models are currently on CPU with execution device "cuda:0" - this is due to the auto CPU offloading strategy we enabled with `components.enable_auto_cpu_offload(device="cuda")`. + +The manager also displays useful info like dtype and memory size for each model. + +```py +>>> components +Components: +====================================================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16 | 1.29 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder_2|null|null | t2i +text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16 | 0.23 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | t2i +unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.78 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | t2i +unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.21 | stabilityai/stable-diffusion-xl-refiner-1.0|unet|null|null | refiner +vae | AutoencoderKL | cpu(cuda:0) | torch.float16 | 0.16 | madebyollin/sdxl-vae-fp16-fix|null|null|null | t2i +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Other Components: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name | Class | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +tokenizer | CLIPTokenizer | t2i +tokenizer_2 | CLIPTokenizer | t2i +scheduler | EulerDiscreteScheduler | t2i +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + + +Now let's reuse components from the t2i pipeline in the refiner. First, let's check the loading status of the refiner pipeline to understand what components are needed: + +```py +>>> refiner_pipe.loader +``` + +Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is already correctly loaded from the refiner repository. We need to load the remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler`. Since these components are already available in the t2i collection, we can reuse them instead of loading duplicates. + +Now let's reuse the components from the t2i pipeline in the refiner. We use the`|` to select multiple components from components manager at once: + +```py +# Reuse components from t2i pipeline (select everything at once) +reuse_components = components.get("text_encoder_2|scheduler|vae|tokenizer_2", as_name_component_tuples=True) +refiner_pipe.update_components(**dict(reuse_components)) +``` + +You'll see warnings indicating that these components already exist in the components manager: + +```out +component 'text_encoder_2' already exists as 'text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0' +component 'tokenizer_2' already exists as 'tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d' +component 'scheduler' already exists as 'scheduler_e3435f63-266a-4427-9383-eb812e830fe8' +component 'vae' already exists as 'vae_357eee6a-4a06-46f1-be83-494f7d60ca69' +``` + +These warnings are expected and indicate that the components manager is correctly identifying that these components are already loaded. The system will reuse the existing components rather than creating duplicates. + +Let's check the components manager again to see the updated state: + +```py +>>> components +Components: +====================================================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16 | 0.23 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | t2i +text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16 | 1.29 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder_2|null|null | t2i + | | | | | | refiner +vae | AutoencoderKL | cpu(cuda:0) | torch.float16 | 0.16 | madebyollin/sdxl-vae-fp16-fix|null|null|null | t2i + | | | | | | refiner +unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.78 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | t2i +unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.21 | stabilityai/stable-diffusion-xl-refiner-1.0|unet|null|null | refiner +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Other Components: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name | Class | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +tokenizer_2 | CLIPTokenizer | t2i + | | refiner +tokenizer | CLIPTokenizer | t2i +scheduler | EulerDiscreteScheduler | t2i + | | refiner +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +Notice how `text_encoder_2`, `vae`, `tokenizer_2`, and `scheduler` now appear under both "t2i" and "refiner" collections. + +We can start to generate an image with the t2i pipeline and refine it. + +First to run the prompt through text_node to get prompt embeddings + + + +💡 don't forget to `text_node.doc` to find out what outputs are available and set the `output` argument accordingly + + + +```py +prompt = "A crystal orb resting on a wooden table with a yellow rubber duck, surrounded by aged scrolls and alchemy tools, illuminated by candlelight, detailed texture, high resolution image" + +text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"]) +``` + +Now generate latents with t2i pipeline and then refine with refiner. Note that both our `t2i_pipe` and `refiner_pipe` do not have decoder steps since we separated them out earlier, so we need to use `output="latents"` instead of `output="images"`. + + + +💡 `t2i_pipe.blocks` shows you what steps this pipeline takes. You can see that our `t2i_pipe` no longer includes the `text_encoder` and `decode` steps since we removed them earlier when we popped them out to create separate nodes. + +```py +>>> t2i_pipe.blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Components: + scheduler (`EulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + unet (`UNet2DConditionModel`) + + Blocks: + [0] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [1] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [2] prepare_latents (StableDiffusionXLPrepareLatentsStep) + Description: Prepare latents step that prepares the latents for the text-to-image generation process + + [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the text-to-image generation process + + [4] denoise (StableDiffusionXLDenoiseLoop) + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + + +) +``` + + + +```py +latents = t2i_pipe(**text_embeddings, num_inference_steps=25, output="latents") +refined_latents = refiner_pipe(image_latents=latents, prompt=prompt, num_inference_steps=10, output="latents") +``` + +To get the final images, we need to pass the latents through our separate decoder node: + +```py +image = decoder_node(latents=latents, output="images")[0] +refined_image = decoder_node(latents=refined_latents, output="images")[0] +``` + +# YiYi TODO: maybe more on controlnet/lora/ip-adapter + + + + + + From ffbaa890ba603dc556bf2076f6f6e68cce8ab840 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 08:55:06 +0200 Subject: [PATCH 084/170] move save_pretrained to the correct place --- .../modular_pipelines/modular_pipeline.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c26a9c7c8a76..cdb28519f440 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -248,7 +248,7 @@ def format_value(v): class ModularPipelineBlocks(ConfigMixin): """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks """ config_name = "config.json" @@ -307,6 +307,20 @@ def from_pretrained( } return block_cls(**block_kwargs) + + def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): + # TODO: factor out this logic. + cls_name = self.__class__.__name__ + + full_mod = type(self).__module__ + module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + auto_map = {f"{parent_module}": f"{module}.{cls_name}"} + + self.register_to_config(auto_map=auto_map) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + self._internal_dict = FrozenDict(config) def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ @@ -532,21 +546,6 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(param_name, param, input_param.kwargs_type) - def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): - # TODO: factor out this logic. - cls_name = self.__class__.__name__ - - full_mod = type(self).__module__ - module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") - parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] - auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - _component_names = [c.name for c in self.expected_components] - - self.register_to_config(auto_map=auto_map, _component_names=_component_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - self._internal_dict = FrozenDict(config) - def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ @@ -2366,9 +2365,7 @@ def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): self.loader = loader def __repr__(self): - blocks_class = self.blocks.__class__.__name__ - loader_class = self.loader.__class__.__name__ - return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})" + return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)" @property def default_call_parameters(self) -> Dict[str, Any]: From cdaaa40d31ded7c68aa0c737d6a67b2d17f88e07 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 08:56:08 +0200 Subject: [PATCH 085/170] update ComponentSpec.from_component, only update config if it is created with from_config --- src/diffusers/modular_pipelines/modular_pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 868c09106043..c83b2abf50a7 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -117,7 +117,7 @@ def from_component(cls, name: str, component: Any) -> Any: type_hint = component.__class__ default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - if isinstance(component, ConfigMixin): + if isinstance(component, ConfigMixin) and default_creation_method == "from_config": config = component.config else: config = None From 1c9f0a83c94e0b00c1c49dc1584999be961b806d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 09:14:19 +0200 Subject: [PATCH 086/170] ujpdate toctree --- docs/source/en/_toctree.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e940e60b1151..09299310d22d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -91,6 +91,7 @@ title: Developer Guide - local: modular_diffusers/quicktour title: Quicktour + title: Modular Diffusers - sections: - local: using-diffusers/cogvideox title: CogVideoX From c0327e493e13af61ed38f9f8de2cc5942b744bc6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 10:49:09 +0200 Subject: [PATCH 087/170] update init --- src/diffusers/__init__.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 73ddfeafeda2..164ee216f3e3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -565,25 +565,14 @@ "WuerstchenPriorPipeline", ] ) - - -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_objects # noqa F403 - - _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") - ] - -else: _import_structure["modular_pipelines"].extend( [ "StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader", ] ) + + try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -1193,16 +1182,11 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_objects import * # noqa F403 - else: from .modular_pipelines import ( StableDiffusionXLAutoPipeline, StableDiffusionXLModularLoader, ) + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() From 5917d7039f874eaa39a900a12058c414ab39a8da Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 11:04:25 +0200 Subject: [PATCH 088/170] remove lora related changes --- src/diffusers/loaders/lora_base.py | 6 ++---- src/diffusers/loaders/lora_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 03ec3046ce71..469b4453c1cf 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -738,10 +738,8 @@ def set_adapters( # Decompose weights into weights for denoiser and text encoders. _component_adapter_weights = {} for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is None: - logger.warning(f"Model {component} not found in pipeline.") - continue + model = getattr(self, component) + for adapter_name, weights in zip(adapter_names, adapter_weights): if isinstance(weights, dict): component_adapter_weights = weights.pop(component, None) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3b2de77e4db3..4fea005cbc39 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -675,7 +675,7 @@ def load_lora_weights( kwargs["return_lora_metadata"] = True state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config if hasattr(self, "unet") else None, + unet_config=self.unet.config, **kwargs, ) From 8c038f0e62c594f0b1d72ed6079e3068617ce787 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 24 Jun 2025 23:05:23 -1000 Subject: [PATCH 089/170] Update src/diffusers/loaders/lora_base.py --- src/diffusers/loaders/lora_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 469b4453c1cf..16f0d4836505 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -438,7 +438,7 @@ def _func_optionally_disable_offloading(_pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None: + if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: From cb328d3ff988d179d62b012d0c0b91d0964ab7bc Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 24 Jun 2025 23:12:26 -1000 Subject: [PATCH 090/170] Apply suggestions from code review --- src/diffusers/pipelines/pipeline_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6431a48dc87d..ccc714289df9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -466,7 +466,7 @@ def module_is_offloaded(module): module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." @@ -483,7 +483,6 @@ def module_is_offloaded(module): "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation." ) - # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) if pipeline_is_offloaded and device_type in ["cuda", "xpu"]: @@ -1162,11 +1161,9 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will automatically detect the available accelerator and use. """ - self._maybe_raise_error_if_group_offload_active(raise_error=True) - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." @@ -1288,7 +1285,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") self.remove_all_hooks() - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." From 7d2a633e02724ded6960d9cd4e5515c366e82698 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 11:26:36 +0200 Subject: [PATCH 091/170] style --- src/diffusers/__init__.py | 24 +- src/diffusers/commands/custom_blocks.py | 11 +- src/diffusers/commands/diffusers_cli.py | 2 +- .../guiders/adaptive_projected_guidance.py | 23 +- src/diffusers/guiders/auto_guidance.py | 21 +- .../guiders/classifier_free_guidance.py | 15 +- .../classifier_free_zero_star_guidance.py | 17 +- src/diffusers/guiders/guider_utils.py | 16 +- src/diffusers/guiders/skip_layer_guidance.py | 25 +- .../guiders/smoothed_energy_guidance.py | 25 +- .../tangential_classifier_free_guidance.py | 19 +- src/diffusers/hooks/layer_skip.py | 15 +- .../hooks/smoothed_energy_guidance_utils.py | 20 +- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/components_manager.py | 175 ++-- .../modular_pipelines/modular_pipeline.py | 335 ++++---- .../modular_pipeline_utils.py | 167 ++-- src/diffusers/modular_pipelines/node_utils.py | 93 ++- .../stable_diffusion_xl/__init__.py | 21 +- .../stable_diffusion_xl/before_denoise.py | 249 +++--- .../stable_diffusion_xl/decoders.py | 28 +- .../stable_diffusion_xl/denoise.py | 782 ++---------------- .../stable_diffusion_xl/encoders.py | 119 ++- .../modular_block_mappings.py | 42 +- .../stable_diffusion_xl/modular_loader.py | 9 +- .../modular_pipeline_presets.py | 11 +- src/diffusers/utils/dynamic_modules_utils.py | 4 +- 28 files changed, 825 insertions(+), 1449 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 164ee216f3e3..18d90be500b7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -794,8 +794,8 @@ LayerSkipConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, - apply_layer_skip, apply_faster_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( @@ -875,6 +875,13 @@ WanTransformer3DModel, WanVACETransformer3DModel, ) + from .modular_pipelines import ( + ComponentsManager, + ComponentSpec, + ModularLoader, + ModularPipeline, + ModularPipelineBlocks, + ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -907,13 +914,6 @@ ScoreSdeVePipeline, StableDiffusionMixin, ) - from .modular_pipelines import ( - ModularLoader, - ModularPipeline, - ModularPipelineBlocks, - ComponentSpec, - ComponentsManager, - ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -978,6 +978,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) from .pipelines import ( AllegroPipeline, AltDiffusionImg2ImgPipeline, @@ -1182,10 +1186,6 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .modular_pipelines import ( - StableDiffusionXLAutoPipeline, - StableDiffusionXLModularLoader, - ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index d2f2de3a8f9a..f532e8b775fd 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -18,10 +18,11 @@ """ import ast -from argparse import ArgumentParser, Namespace -from pathlib import Path import importlib.util import os +from argparse import ArgumentParser, Namespace +from pathlib import Path + from ..utils import logging from . import BaseDiffusersCLICommand @@ -57,7 +58,7 @@ def run(self): # determine the block to be saved. out = self._get_class_names(self.block_module_name) classes_found = list({cls for cls, _ in out}) - + if self.block_class_name is not None: child_class, parent_class = self._choose_block(out, self.block_class_name) if child_class is None and parent_class is None: @@ -125,9 +126,9 @@ def _get_base_name(self, node: ast.expr): val = self._get_base_name(node.value) return f"{val}.{node.attr}" if val else node.attr return None - + def _create_automap(self, parent_class, child_class): module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] auto_map = {f"{parent_class}": f"{module}.{child_class}"} return {"auto_map": auto_map} - + diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py index f291303d1e79..a27ac24f2a3e 100644 --- a/src/diffusers/commands/diffusers_cli.py +++ b/src/diffusers/commands/diffusers_cli.py @@ -15,9 +15,9 @@ from argparse import ArgumentParser +from .custom_blocks import CustomBlocksCommand from .env import EnvironmentCommand from .fp16_safetensors import FP16SafetensorsCommand -from .custom_blocks import CustomBlocksCommand def main(): diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index ef2f3f2c8420..f1a6096c4d6a 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -74,10 +75,10 @@ def __init__( self.momentum_buffer = None def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) @@ -123,19 +124,19 @@ def num_conditions(self) -> int: def _is_apg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close @@ -160,25 +161,25 @@ def normalized_guidance( ): diff = pred_cond - pred_uncond dim = [-i for i in range(1, len(diff.shape))] - + if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average - + if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor - + v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = diff_orthogonal + eta * diff_parallel - + pred = pred_cond if use_original_formulation else pred_uncond pred = pred + guidance_scale * normalized_update - + return pred diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 791cc582add2..83120c20ceca 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -113,18 +114,18 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: if self._is_ag_enabled() and self.is_unconditional: for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): _apply_layer_skip_hook(denoiser, config, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: if self._is_ag_enabled() and self.is_unconditional: for name in self._auto_guidance_hook_names: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -144,9 +145,9 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - + return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 @@ -161,17 +162,17 @@ def num_conditions(self) -> int: def _is_ag_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index a459e51cd083..faeba0971157 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -74,12 +75,12 @@ def __init__( self.guidance_scale = guidance_scale self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -116,17 +117,17 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index a722f2605036..b4dee9295ab6 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -72,12 +73,12 @@ def __init__( self.zero_init_steps = zero_init_steps self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -106,7 +107,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 @@ -121,19 +122,19 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index e8e873f5c88f..87109eb048ed 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -58,10 +58,10 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): def disable(self): self._enabled = False - + def enable(self): self._enabled = True - + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps @@ -104,14 +104,14 @@ def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." ) self._input_fields = kwargs - + def prepare_models(self, denoiser: torch.nn.Module) -> None: """ Prepares the models for the guidance technique on a given batch of data. This method should be overridden in subclasses to implement specific model preparation logic. """ self._count_prepared += 1 - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: """ Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in @@ -119,7 +119,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: modifications made during `prepare_models`. """ pass - + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") @@ -139,15 +139,15 @@ def forward(self, *args, **kwargs) -> Any: @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") - + @property def is_unconditional(self) -> bool: return not self.is_conditional - + @property def num_conditions(self) -> int: raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") - + @classmethod def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": """ diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 7c19f6391f41..ffe00ea7db33 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -148,19 +149,19 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -204,7 +205,7 @@ def forward( pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 or self._count_prepared == 3 @@ -221,31 +222,31 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def _is_slg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step - + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) - + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 3986da913f82..ab21b6d9526d 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -141,19 +142,19 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module): if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -197,7 +198,7 @@ def forward( pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 or self._count_prepared == 3 @@ -214,31 +215,31 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def _is_seg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step - + is_zero = math.isclose(self.seg_guidance_scale, 0.0) - + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 017693fd9f07..fdcdaf8dcb3a 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -63,10 +64,10 @@ def __init__( self.use_original_formulation = use_original_formulation def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -101,24 +102,24 @@ def num_conditions(self) -> int: def _is_tcfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: - cond_dtype = pred_cond.dtype + cond_dtype = pred_cond.dtype preds = torch.stack([pred_cond, pred_uncond], dim=1).float() preds = preds.flatten(2) U, S, Vh = torch.linalg.svd(preds, full_matrices=False) @@ -129,9 +130,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) x_Vh_V = torch.matmul(x_Vh, Vh_modified) pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) - + pred = pred_cond if use_original_formulation else pred_uncond shift = pred_cond - pred_uncond pred = pred + guidance_scale * shift - + return pred diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 65a99464ba2f..6b847271c97b 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -20,7 +20,12 @@ from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn +from ._common import ( + _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, + _ATTENTION_CLASSES, + _FEEDFORWARD_CLASSES, + _get_submodule_from_fqn, +) from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -198,15 +203,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam for i, block in enumerate(transformer_blocks): if i not in config.indices: continue - + blocks_found = True - + if config.skip_attention and config.skip_ff: logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") registry = HookRegistry.check_if_exists_or_initialize(block) hook = TransformerBlockSkipHook(config.dropout) registry.register_hook(hook, name) - + elif config.skip_attention or config.skip_attention_scores: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: @@ -215,7 +220,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) registry.register_hook(hook, name) - + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index f0366e29887f..353ce7289444 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional import torch import torch.nn.functional as F @@ -67,7 +67,7 @@ def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.T def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK - + if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: if hasattr(module, identifier): @@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " "`fqn` (fully qualified name) that identifies a stack of transformer blocks." ) - + if config._query_proj_identifiers is None: config._query_proj_identifiers = ["to_q"] - + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) blocks_found = False for i, block in enumerate(transformer_blocks): if i not in config.indices: continue - + blocks_found = True - + for submodule_name, submodule in block.named_modules(): if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: continue @@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth registry = HookRegistry.check_if_exists_or_initialize(query_proj) hook = SmoothedEnergyGuidanceHook(blur_sigma) registry.register_hook(hook, name) - + if not blocks_found: raise ValueError( f"Could not find any transformer blocks matching the provided indices {config.indices} and " @@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma in the future without warning or guarantee of reproducibility. """ assert query.ndim == 3 - + is_inf = sigma > sigma_threshold_inf batch_size, seq_len, embed_dim = query.shape @@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = query[:, :num_square_tokens, :] query_slice = query_slice.permute(0, 2, 1) query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) - + if is_inf: kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) kernel_size_half = (kernel_size - 1) / 2 @@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.permute(0, 2, 1) query[:, :num_square_tokens, :] = query_slice.clone() - + return query diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index a5f5e6376b04..335d7e623f07 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -102,8 +102,8 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import ( FluxIPAdapterMixin, IPAdapterMixin, - SD3IPAdapterMixin, ModularIPAdapterMixin, + SD3IPAdapterMixin, ) from .lora_pipeline import ( AmusedLoraLoaderMixin, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 4499634d9fbd..f6e398268ca0 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -49,13 +49,14 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .components_manager import ComponentsManager from .modular_pipeline import ( AutoPipelineBlocks, BlockState, LoopSequentialPipelineBlocks, ModularLoader, - ModularPipelineBlocks, ModularPipeline, + ModularPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, @@ -70,7 +71,6 @@ StableDiffusionXLAutoPipeline, StableDiffusionXLModularLoader, ) - from .components_manager import ComponentsManager else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index bdc24d474a32..3f22fa7115be 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -12,24 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import time +import uuid from collections import OrderedDict from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy +from typing import Any, Dict, List, Optional, Union import torch -import time -from dataclasses import dataclass from ..utils import ( is_accelerate_available, logging, ) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec - - -import uuid if is_accelerate_available(): @@ -237,12 +232,12 @@ def search_best_candidate(module_sizes, min_memory_offload): class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ Lookup component_ids by name, collection, or load_id. @@ -251,7 +246,7 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde components = self.components if name: - ids_by_name = set() + ids_by_name = set() for component_id, component in components.items(): comp_name = self._id_to_name(component_id) if comp_name == name: @@ -272,16 +267,16 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde ids_by_load_id.add(name) else: ids_by_load_id = set(components.keys()) - + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) return ids - + @staticmethod def _id_to_name(component_id: str): return "_".join(component_id.split("_")[:-1]) - + def add(self, name, component, collection: Optional[str] = None): - + component_id = f"{name}_{uuid.uuid4()}" # check for duplicated components @@ -305,7 +300,7 @@ def add(self, name, component, collection: Optional[str] = None): if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] - + if components_with_same_load_id: existing = ", ".join(components_with_same_load_id) logger.warning( @@ -320,7 +315,7 @@ def add(self, name, component, collection: Optional[str] = None): if collection: if collection not in self.collections: self.collections[collection] = set() - if not component_id in self.collections[collection]: + if component_id not in self.collections[collection]: comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) for comp_id in comp_ids_in_collection: logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") @@ -331,8 +326,8 @@ def add(self, name, component, collection: Optional[str] = None): logger.info(f"Added component '{name}' as '{component_id}'") if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - + self.enable_auto_cpu_offload(self._auto_offload_device) + return component_id @@ -341,14 +336,14 @@ def remove(self, component_id: str = None): if component_id not in self.components: logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - + component = self.components.pop(component_id) self.added_time.pop(component_id) for collection in self.collections: if component_id in self.collections[collection]: self.collections[collection].remove(component_id) - + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) else: @@ -386,7 +381,7 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N Dictionary mapping component IDs to components, or list of (base_name, component) tuples if as_name_component_tuples=True """ - + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} @@ -397,16 +392,16 @@ def get_base_name(component_id): if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return component_id - + if names is None: if as_name_component_tuples: return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] else: return components - + # Create mapping from component_id to base_name for all components base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - + def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. @@ -417,124 +412,124 @@ def matches_pattern(component_id, pattern, exact_match=False): exact_match: If True, only exact matches to base_name are considered """ base_name = base_names[component_id] - + # Exact match with base name if exact_match: return pattern == base_name - + # Prefix match (ends with *) elif pattern.endswith('*'): prefix = pattern[:-1] return base_name.startswith(prefix) - + # Contains match (starts with *) elif pattern.startswith('*'): search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] return search in base_name - + # Exact match (no wildcards) else: return pattern == base_name - + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') if is_not_pattern: names = names[1:] # Remove the ! prefix - + # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') matches = {} - + for comp_id, comp in components.items(): # For OR patterns with exact names (no wildcards), we do exact matching on base names exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - + # Check if any of the terms match this component should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - + # Flip the decision if this is a NOT pattern if is_not_pattern: should_include = not should_include - + if should_include: matches[comp_id] = comp - + log_msg = "NOT " if is_not_pattern else "" match_type = "exactly matching" if exact_match else "matching any of patterns" logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - + # Try exact match with a base name elif any(names == base_name for base_name in base_names.values()): # Find all components with this base name matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (base_names[comp_id] == names) != is_not_pattern } - + if is_not_pattern: logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - + # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") else: logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - + # Contains match (starts with *) elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - + # Substring match (no wildcards, but not an exact component name) elif any(names in base_name for base_name in base_names.values()): matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (names in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - + else: raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") - + if not matches: raise ValueError(f"No components found matching pattern '{names}'") - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] else: return matches - + elif isinstance(names, list): results = {} for name in names: result = self.get(name, collection, load_id, as_name_component_tuples=False) results.update(result) - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in results.items()] else: return results - + else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -595,14 +590,14 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] - + # Build complete info dict first info = { "model_id": component_id, "added_time": self.added_time[component_id], "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } - + # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): # Check for hook information @@ -610,7 +605,7 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str execution_device = None if has_hook and hasattr(component._hf_hook, "execution_device"): execution_device = component._hf_hook.execution_device - + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), @@ -631,8 +626,8 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str if any("IPAdapter" in ptype for ptype in processor_types): # Then get scales only from IP-Adapter processors scales = { - k: v.scale - for k, v in processors.items() + k: v.scale + for k, v in processors.items() if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ } if scales: @@ -646,7 +641,7 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str else: # List of fields requested, return dict with just those fields return {k: v for k, v in info.items() if k in fields} - + return info def __repr__(self): @@ -659,13 +654,13 @@ def get_simple_name(name): if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return name - + # Extract load_id if available def get_load_id(component): if hasattr(component, "_diffusers_load_id"): return component._diffusers_load_id return "N/A" - + # Format device info compactly def format_device(component, info): if not info["has_hook"]: @@ -674,18 +669,18 @@ def format_device(component, info): device = str(getattr(component, 'device', 'N/A')) exec_device = str(info['execution_device'] or 'N/A') return f"{device}({exec_device})" - + # Get all simple names to calculate width simple_names = [get_simple_name(id) for id in self.components.keys()] - + # Get max length of load_ids for models load_ids = [ - get_load_id(component) - for component in self.components.values() + get_load_id(component) + for component in self.components.values() if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - + # Get all collections for each component component_collections = {} for name in self.components.keys(): @@ -695,11 +690,11 @@ def format_device(component, info): component_collections[name].append(coll) if not component_collections[name]: component_collections[name] = ["N/A"] - + # Find the maximum collection name length all_collections = [coll for colls in component_collections.values() for coll in colls] max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 - + col_widths = { "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), @@ -736,21 +731,21 @@ def format_device(component, info): device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) - + # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" - + # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" - + output += dash_line # Other components section @@ -766,17 +761,17 @@ def format_device(component, info): for name, component in others.items(): info = self.get_model_info(name) simple_name = get_simple_name(name) - + # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" - + # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" - + output += dash_line # Add additional component info @@ -789,8 +784,8 @@ def format_device(component, info): if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" - + output += " IP-Adapter: Enabled\n" + return output def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): @@ -821,13 +816,13 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = from ..pipelines.pipeline_utils import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): - + if component is None: continue - + # Add prefix if specified component_name = f"{prefix}_{name}" if prefix else name - + if component_name not in self.components: self.add(component_name, component) else: @@ -860,15 +855,15 @@ def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") return self.components[component_id] - + results = self.get(name, collection, load_id) - + if not results: raise ValueError(f"No components found matching '{name}'") - + if len(results) > 1: raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: @@ -894,17 +889,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: if value_tuple not in value_to_keys: value_to_keys[value_tuple] = [] value_to_keys[value_tuple].append(key) - + def find_common_prefix(keys: List[str]) -> str: """Find the shortest common prefix among a list of dot-separated keys.""" if not keys: return "" if len(keys) == 1: return keys[0] - + # Split all keys into parts key_parts = [k.split('.') for k in keys] - + # Find how many initial parts are common common_length = 0 for parts in zip(*key_parts): @@ -912,10 +907,10 @@ def find_common_prefix(keys: List[str]) -> str: common_length += 1 else: break - + if common_length == 0: return "" - + # Return the common prefix return '.'.join(key_parts[0][:common_length]) @@ -929,5 +924,5 @@ def find_common_prefix(keys: List[str]) -> str: summary[prefix] = value else: summary[""] = value # Use empty string if no common prefix - + return summary diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index cdb28519f440..0d7bec5a5cc2 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -11,49 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect - - +import os import traceback import warnings from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional -from typing_extensions import Self from copy import deepcopy - +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from tqdm.auto import tqdm -import re -import os -import importlib - from huggingface_hub.utils import validate_hf_hub_args +from tqdm.auto import tqdm +from typing_extensions import Self from ..configuration_utils import ConfigMixin, FrozenDict +from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj from ..utils import ( + PushToHubMixin, is_accelerate_available, logging, - PushToHubMixin, ) -from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from .components_manager import ComponentsManager from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, InputParam, + InsertableOrderedDict, OutputParam, format_components, format_configs, format_inputs_short, format_intermediates_short, make_doc_string, - InsertableOrderedDict ) -from .components_manager import ComponentsManager -from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code -from copy import deepcopy + if is_accelerate_available(): import accelerate @@ -118,7 +113,7 @@ def get_input(self, key: str, default: Any = None) -> Any: def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} - + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ Get all inputs with matching kwargs_type. @@ -165,7 +160,7 @@ def format_value(v): inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - + # Format input_kwargs and intermediate_kwargs input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) @@ -180,7 +175,7 @@ def format_value(v): ) -@dataclass +@dataclass class BlockState: """ Container for block state data with attribute access and formatted representation. @@ -192,11 +187,11 @@ def __init__(self, **kwargs): def __getitem__(self, key: str): # allows block_state["foo"] return getattr(self, key, None) - + def __setitem__(self, key: str, value: Any): # allows block_state["foo"] = "bar" setattr(self, key, value) - + def as_dict(self): """ Convert BlockState to a dictionary. @@ -211,21 +206,21 @@ def format_value(v): # Handle tensors directly if hasattr(v, "shape") and hasattr(v, "dtype"): return f"Tensor(dtype={v.dtype}, shape={v.shape})" - + # Handle lists of tensors elif isinstance(v, list): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"List[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle tuples of tensors elif isinstance(v, tuple): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle dicts with tensor values elif isinstance(v, dict): formatted_dict = {} @@ -238,7 +233,7 @@ def format_value(v): else: formatted_dict[k] = repr(val) return formatted_dict - + # Default case return repr(v) @@ -261,7 +256,7 @@ def _get_signature_keys(cls, obj): expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters - + @classmethod def from_pretrained( @@ -311,17 +306,17 @@ def from_pretrained( def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): # TODO: factor out this logic. cls_name = self.__class__.__name__ - - full_mod = type(self).__module__ + + full_mod = type(self).__module__ module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") - parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - + self.register_to_config(auto_map=auto_map) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) config = dict(self.config) self._internal_dict = FrozenDict(config) - + def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. @@ -329,22 +324,22 @@ def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.Pa loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) - + # Create deep copies to avoid modifying the original specs component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) # Create the loader with the updated specs specs = component_specs + config_specs - + loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection) modular_pipeline = ModularPipeline(blocks=self, loader=loader) return modular_pipeline class PipelineBlock(ModularPipelineBlocks): - + model_name = None - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -354,12 +349,12 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: return [] - + @property def expected_configs(self) -> List[ConfigSpec]: return [] - + @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" @@ -394,7 +389,7 @@ def _get_required_inputs(self): @property def required_inputs(self) -> List[str]: return self._get_required_inputs() - + def _get_required_intermediates_inputs(self): input_names = [] @@ -403,7 +398,7 @@ def _get_required_intermediates_inputs(self): input_names.append(input_param.name) return input_names - # YiYi TODO: maybe we do not need this, it is only used in docstring, + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -460,9 +455,9 @@ def __repr__(self): @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -474,7 +469,7 @@ def doc(self): def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - + # Check inputs for input_param in self.inputs: if input_param.name: @@ -514,14 +509,14 @@ def get_block_state(self, state: PipelineState) -> dict: data[k] = v data[input_param.kwargs_type][k] = v return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) - + for input_param in self.intermediates_inputs: if hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) @@ -561,7 +556,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li """ combined_dict = {} # name -> InputParam value_sources = {} # name -> block_name - + for block_name, inputs in named_input_lists: for input_param in inputs: if input_param.name is None and input_param.kwargs_type is not None: @@ -570,8 +565,8 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li input_name = input_param.name if input_name in combined_dict: current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and + if (current_param.default is not None and + input_param.default is not None and current_param.default != input_param.default): warnings.warn( f"Multiple different default values found for input '{input_name}': " @@ -584,7 +579,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li else: combined_dict[input_name] = input_param value_sources[input_name] = block_name - + return list(combined_dict.values()) def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: @@ -599,12 +594,12 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: Combined list of unique OutputParam objects """ combined_dict = {} # name -> OutputParam - + for block_name, outputs in named_output_lists: for output_param in outputs: if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): combined_dict[output_param.name] = output_param - + return list(combined_dict.values()) @@ -630,7 +625,7 @@ def __init__(self): if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last + # can only have 1 or 0 default block, and has to put in the last # the order of blocksmatters here because the first block with matching trigger will be dispatched # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img @@ -650,7 +645,7 @@ def __init__(self): @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -687,8 +682,8 @@ def required_inputs(self) -> List[str]: required_by_all.intersection_update(block_required) return list(required_by_all) - - # YiYi TODO: maybe we do not need this, it is only used in docstring, + + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -736,7 +731,7 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] @@ -779,24 +774,24 @@ def _get_trigger_inputs(self): """ def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them if hasattr(block, 'blocks'): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + trigger_inputs = set(self.block_trigger_inputs) trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - + return trigger_inputs @property @@ -812,7 +807,7 @@ def __repr__(self): else f"{class_name}(\n" ) - + if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -836,7 +831,7 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -860,7 +855,7 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -870,27 +865,27 @@ def __repr__(self): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -905,15 +900,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): block_classes = [] block_names = [] - + @property def description(self): return "" - + @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def expected_components(self): @@ -944,7 +939,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo A new SequentialPipelineBlocks instance """ instance = cls() - + # Create instances if classes are provided blocks = InsertableOrderedDict() for name, block in blocks_dict.items(): @@ -952,12 +947,12 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo blocks[name] = block() else: blocks[name] = block - + instance.block_classes = [block.__class__ for block in blocks.values()] instance.block_names = list(blocks.keys()) instance.blocks = blocks return instance - + def __init__(self): blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): @@ -975,10 +970,10 @@ def required_inputs(self) -> List[str]: for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - - # YiYi TODO: maybe we do not need this, it is only used in docstring, + + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -1007,7 +1002,7 @@ def get_inputs(self): @property def intermediates_inputs(self) -> List[str]: return self.get_intermediates_inputs() - + def get_intermediates_inputs(self): inputs = [] outputs = set() @@ -1025,7 +1020,7 @@ def get_intermediates_inputs(self): should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -1043,7 +1038,7 @@ def intermediates_outputs(self) -> List[str]: named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: @@ -1063,7 +1058,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise return pipeline, state - + def _get_trigger_inputs(self): """ Returns a set of all unique trigger input values found in the blocks. @@ -1071,21 +1066,21 @@ def _get_trigger_inputs(self): """ def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them if hasattr(block, 'blocks'): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + return fn_recursive_get_trigger(self.blocks) @property @@ -1097,7 +1092,7 @@ def _traverse_trigger_blocks(self, trigger_inputs): active_triggers = set(trigger_inputs) def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - + # sequential(include loopsequential) or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): @@ -1114,7 +1109,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): if hasattr(block, 'outputs'): active_triggers.update(out.name for out in block.outputs) return result_blocks - + # auto else: # Find first block_trigger_input that matches any value in our active_triggers @@ -1125,12 +1120,12 @@ def fn_recursive_traverse(block, block_name, active_triggers): this_block = block.trigger_to_block_map[trigger_input] matching_trigger = trigger_input break - + # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] matching_trigger = None - + if this_block is not None: # sequential/auto (keep traversing) if hasattr(this_block, 'blocks'): @@ -1144,13 +1139,13 @@ def fn_recursive_traverse(block, block_name, active_triggers): active_triggers.update(out.name for out in this_block.outputs) return result_blocks - + all_blocks = OrderedDict() for block_name, block in self.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) all_blocks.update(blocks_to_update) return all_blocks - + def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs @@ -1164,7 +1159,7 @@ def get_execution_blocks(self, *trigger_inputs): f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" ) trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - + if trigger_inputs is None: if None in trigger_inputs_all: trigger_inputs = [None] @@ -1172,7 +1167,7 @@ def get_execution_blocks(self, *trigger_inputs): trigger_inputs = [trigger_inputs_all[0]] blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - + def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ @@ -1182,7 +1177,7 @@ def __repr__(self): else f"{class_name}(\n" ) - + if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -1206,7 +1201,7 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1230,7 +1225,7 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -1240,27 +1235,27 @@ def __repr__(self): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -1276,7 +1271,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): model_name = None block_classes = [] block_names = [] - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -1285,7 +1280,7 @@ def description(self) -> str: @property def loop_expected_components(self) -> List[ComponentSpec]: return [] - + @property def loop_expected_configs(self) -> List[ConfigSpec]: return [] @@ -1365,8 +1360,8 @@ def get_inputs(self): @property def inputs(self): return self.get_inputs() - - + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs @property def intermediates_inputs(self): @@ -1392,7 +1387,7 @@ def get_intermediates_inputs(self): should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -1414,10 +1409,10 @@ def required_inputs(self) -> List[str]: for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - # YiYi TODO: maybe we do not need this, it is only used in docstring, + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -1441,7 +1436,7 @@ def intermediates_outputs(self) -> List[str]: if output.name not in set([output.name for output in combined_outputs]): combined_outputs.append(output) return combined_outputs - + # YiYi TODO: this need to be thought about more # copied from SequentialPipelineBlocks @property @@ -1454,7 +1449,7 @@ def __init__(self): for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks - + @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. @@ -1485,15 +1480,15 @@ def loop_step(self, components, state: PipelineState, **kwargs): logger.error(error_msg) raise return components, state - + def __call__(self, components, state: PipelineState) -> PipelineState: raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - - + + def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - + # Check inputs for input_param in self.inputs: if input_param.name: @@ -1533,7 +1528,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[k] = v data[input_param.kwargs_type][k] = v return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): @@ -1563,17 +1558,17 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, expected_configs=self.expected_configs ) - # modified from SequentialPipelineBlocks, - #(does not need trigger_inputs related part so removed them, + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, # do not need to support auto block for loop blocks) def __repr__(self): class_name = self.__class__.__name__ @@ -1597,7 +1592,7 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1605,10 +1600,10 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): - + # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -1618,18 +1613,18 @@ def __repr__(self): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @torch.compiler.disable @@ -1652,7 +1647,7 @@ def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs -# YiYi TODO: +# YiYi TODO: # 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config # 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader @@ -1696,29 +1691,29 @@ def register_components(self, **kwargs): if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue - + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") if module is not None: # actual library and class name of the module library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config - # e.g. {"repo": "stabilityai/stable-diffusion-2-1", - # "type_hint": ("diffusers", "UNet2DConditionModel"), + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), # "subfolder": "unet", # "variant": None, # "revision": None} component_spec_dict = self._component_spec_to_dict(component_spec) - + else: # if module is None, e.g. self.register_components(unet=None) during __init__ - # we do not update the spec, + # we do not update the spec, # but we still need to update the modular_model_index.json config based oncomponent spec library, class_name = None, None component_spec_dict = self._component_spec_to_dict(component_spec) @@ -1732,7 +1727,7 @@ def register_components(self, **kwargs): if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue - + current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: @@ -1764,7 +1759,7 @@ def register_components(self, **kwargs): self._component_manager.add(name, module, self._collection) - + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): """ @@ -1792,7 +1787,7 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_mod elif name in self._config_specs: self._config_specs[name].default = value - + register_components_dict = {} for name, component_spec in self._component_specs.items(): if component_spec.default_creation_method == "from_config": @@ -1801,7 +1796,7 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_mod component = None register_components_dict[name] = component self.register_components(**register_components_dict) - + default_configs = {} for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default @@ -1844,7 +1839,7 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device - + @property def dtype(self) -> torch.dtype: @@ -1871,12 +1866,6 @@ def components(self) -> Dict[str, Any]: } def update(self, **kwargs): - """ - Update components and configs after instance creation. - - Args: - - """ """ Update components and configuration values after the loader has been instantiated. @@ -1917,7 +1906,7 @@ def update(self, **kwargs): guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") ) ``` - """ + """ # extract component_specs_updates & config_specs_updates from `specs` passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} @@ -1926,7 +1915,7 @@ def update(self, **kwargs): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): @@ -1942,14 +1931,14 @@ def update(self, **kwargs): # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) self._component_specs[name] = new_component_spec - + if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - + created_components = {} for name, component_spec in passed_component_specs.items(): if component_spec.default_creation_method == "from_pretrained": - raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + raise ValueError("ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") created_components[name] = component_spec.create() current_component_spec = self._component_specs[name] # warn if type changed @@ -1991,7 +1980,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): unknown_component_names = set([name for name in component_names if name not in self._component_specs]) if len(unknown_component_names) > 0: logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - + components_to_register = {} for name in components_to_load: spec = self._component_specs[name] @@ -2011,7 +2000,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") - + # Register all components at once self.register_components(**components_to_register) @@ -2033,7 +2022,7 @@ def _maybe_raise_error_if_group_offload_active( ) return True return False - + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to def to(self, *args, **kwargs) -> Self: r""" @@ -2071,7 +2060,7 @@ def to(self, *args, **kwargs) -> Self: Returns: [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ - from ..pipelines.pipeline_utils import _check_bnb_status, DiffusionPipeline + from ..pipelines.pipeline_utils import _check_bnb_status from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version @@ -2227,7 +2216,7 @@ def module_is_offloaded(module): ) return self - # YiYi TODO: + # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. maybe order the json file to make it more readable: configs first, then components def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): @@ -2241,11 +2230,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: config.pop("_configs_names", None) self._internal_dict = FrozenDict(config) - + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) expected_config = set(config_dict.pop("_configs_names")) @@ -2265,7 +2254,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) - + @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ @@ -2432,33 +2421,33 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = else: raise ValueError(f"Output '{output}' is not a valid output type") - + def load_components(self, component_names: Optional[List[str]] = None, **kwargs): self.loader.load(component_names=component_names, **kwargs) - + def update_components(self, **kwargs): self.loader.update(**kwargs) - + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs) return pipeline - + def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - + + @property def doc(self): return self.blocks.doc - + def to(self, *args, **kwargs): self.loader.to(*args, **kwargs) return self - + @property def components(self): - return self.loader.components \ No newline at end of file + return self.loader.components diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index c83b2abf50a7..1b9874bb52bd 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -12,44 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal +import re +from collections import OrderedDict +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List, Literal, Optional, Type, Union +from ..configuration_utils import ConfigMixin, FrozenDict from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin -from collections import OrderedDict + if is_torch_available(): - import torch + pass class InsertableOrderedDict(OrderedDict): def insert(self, key, value, index): items = list(self.items()) - + # Remove key if it already exists to avoid duplicates items = [(k, v) for k, v in items if k != key] - + # Insert at the specified index items.insert(index, (key, value)) - + # Clear and update self self.clear() self.update(items) - + # Return self for method chaining return self - + def __repr__(self): if not self: return "InsertableOrderedDict()" - + items = [] for i, (key, value) in enumerate(self.items()): items.append(f"{i}: ({repr(key)}, {repr(value)})") - + return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])" @@ -85,24 +86,24 @@ class ComponentSpec: variant: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True}) default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - + + def __hash__(self): """Make ComponentSpec hashable, using load_id as the hash value.""" return hash((self.name, self.load_id, self.default_creation_method)) - + def __eq__(self, other): """Compare ComponentSpec objects based on name and load_id.""" if not isinstance(other, ComponentSpec): return False - return (self.name == other.name and - self.load_id == other.load_id and + return (self.name == other.name and + self.load_id == other.load_id and self.default_creation_method == other.default_creation_method) - + @classmethod def from_component(cls, name: str, component: Any) -> Any: """Create a ComponentSpec from a Component created by `create` or `load` method.""" - + if not hasattr(component, "_diffusers_load_id"): raise ValueError("Component is not created by `create` or `load` method") # throw a error if component is created with `create` method but not a subclass of ConfigMixin @@ -113,19 +114,19 @@ def from_component(cls, name: str, component: Any) -> Any: "created with `ComponentSpec.load` method" "or created with `ComponentSpec.create` and a subclass of ConfigMixin" ) - + type_hint = component.__class__ default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - + if isinstance(component, ConfigMixin) and default_creation_method == "from_config": config = component.config else: config = None - + load_spec = cls.decode_load_id(component._diffusers_load_id) - + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) - + @classmethod def loading_fields(cls) -> List[str]: """ @@ -133,8 +134,8 @@ def loading_fields(cls) -> List[str]: (i.e. those whose field.metadata["loading"] is True). """ return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - + + @property def load_id(self) -> str: """ @@ -144,7 +145,7 @@ def load_id(self) -> str: parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] return "|".join(p for p in parts if p) - + @classmethod def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: """ @@ -165,26 +166,26 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating component not created with `load` method). """ - + # Get all loading fields in order loading_fields = cls.loading_fields() result = {f: None for f in loading_fields} if load_id == "null": return result - + # Split the load_id parts = load_id.split("|") - + # Map parts to loading fields by position for i, part in enumerate(parts): if i < len(loading_fields): # Convert "null" string back to None result[loading_fields[i]] = None if part == "null" else part - + return result - - + + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) # the config info is lost in the process @@ -194,11 +195,11 @@ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **k if self.type_hint is None or not isinstance(self.type_hint, type): raise ValueError( - f"`type_hint` is required when using from_config creation method." + "`type_hint` is required when using from_config creation method." ) - + config = config or self.config or {} - + if issubclass(self.type_hint, ConfigMixin): component = self.type_hint.from_config(config, **kwargs) else: @@ -211,17 +212,17 @@ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **k if k in signature_params: init_kwargs[k] = v component = self.type_hint(**init_kwargs) - + component._diffusers_load_id = "null" if hasattr(component, "config"): self.config = component.config - + return component - + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained def load(self, **kwargs) -> Any: """Load component using from_pretrained.""" - + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} # merge loading field value in the spec with user passed values to create load_kwargs @@ -229,8 +230,8 @@ def load(self, **kwargs) -> Any: # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - + raise ValueError("`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + if self.type_hint is None: try: from diffusers import AutoModel @@ -244,17 +245,17 @@ def load(self, **kwargs) -> Any: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: raise ValueError(f"Unable to load {self.name} using load method: {e}") - + self.repo = repo for k, v in load_kwargs.items(): setattr(self, k, v) component._diffusers_load_id = self.load_id - + return component - -@dataclass + +@dataclass class ConfigSpec: """Specification for a pipeline configuration parameter.""" name: str @@ -281,7 +282,7 @@ def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" -@dataclass +@dataclass class OutputParam: """Specification for an output parameter.""" name: str @@ -315,14 +316,14 @@ def format_inputs_short(inputs): """ required_inputs = [param for param in inputs if param.required] optional_inputs = [param for param in inputs if not param.required] - + required_str = ", ".join(param.name for param in required_inputs) optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - + inputs_str = required_str if optional_str: inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - + return inputs_str @@ -353,18 +354,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu else: inp_name = inp.name input_parts.append(inp_name) - + # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} modified_parts = [] new_output_parts = [] - + for out in intermediates_outputs: if out.name in inputs_set: modified_parts.append(out.name) else: new_output_parts.append(out.name) - + result = [] if input_parts: result.append(f" - inputs: {', '.join(input_parts)}") @@ -372,7 +373,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu result.append(f" - modified: {', '.join(modified_parts)}") if new_output_parts: result.append(f" - outputs: {', '.join(new_output_parts)}") - + return "\n".join(result) if result else " (none)" @@ -390,18 +391,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): """ if not params: return "" - + base_indent = " " * indent_level param_indent = " " * (indent_level + 4) desc_indent = " " * (indent_level + 8) formatted_params = [] - + def get_type_str(type_hint): if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] return f"Union[{', '.join(types)}]" return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - + def wrap_text(text, indent, max_length): """Wrap text while preserving markdown links and maintaining indentation.""" words = text.split() @@ -411,7 +412,7 @@ def wrap_text(text, indent, max_length): for word in words: word_length = len(word) + (1 if current_line else 0) - + if current_line and current_length + word_length > max_length: lines.append(" ".join(current_line)) current_line = [word] @@ -419,22 +420,22 @@ def wrap_text(text, indent, max_length): else: current_line.append(word) current_length += word_length - + if current_line: lines.append(" ".join(current_line)) - + return f"\n{indent}".join(lines) - + # Add the header formatted_params.append(f"{base_indent}{header}:") - + for param in params: # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" # YiYi Notes: remove this line if we remove kwargs_type name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name param_str = f"{param_indent}{name} (`{type_str}`" - + # Add optional tag and default value if parameter is an InputParam and optional if hasattr(param, "required"): if not param.required: @@ -442,7 +443,7 @@ def wrap_text(text, indent, max_length): if param.default is not None: param_str += f", defaults to {param.default}" param_str += "):" - + # Add description on a new line with additional indentation and wrapping if param.description: desc = re.sub( @@ -452,9 +453,9 @@ def wrap_text(text, indent, max_length): ) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" - + formatted_params.append(param_str) - + return "\n\n".join(formatted_params) @@ -500,42 +501,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty """ if not components: return "" - + base_indent = " " * indent_level component_indent = " " * (indent_level + 4) formatted_components = [] - + # Add the header formatted_components.append(f"{base_indent}Components:") if add_empty_lines: formatted_components.append("") - + # Add each component with optional empty lines between them for i, component in enumerate(components): # Get type name, handling special cases type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - + component_desc = f"{component_indent}{component.name} (`{type_name}`)" if component.description: component_desc += f": {component.description}" - + # Get the loading fields dynamically loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) if field_value is not None: loading_field_values.append(f"{field_name}={field_value}") - + # Add loading field information if available if loading_field_values: component_desc += f" [{', '.join(loading_field_values)}]" - + formatted_components.append(component_desc) - + # Add an empty line after each component except the last one if add_empty_lines and i < len(components) - 1: formatted_components.append("") - + return "\n".join(formatted_components) @@ -553,27 +554,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines """ if not configs: return "" - + base_indent = " " * indent_level config_indent = " " * (indent_level + 4) formatted_configs = [] - + # Add the header formatted_configs.append(f"{base_indent}Configs:") if add_empty_lines: formatted_configs.append("") - + # Add each config with optional empty lines between them for i, config in enumerate(configs): config_desc = f"{config_indent}{config.name} (default: {config.default})" if config.description: config_desc += f": {config.description}" formatted_configs.append(config_desc) - + # Add an empty line after each config except the last one if add_empty_lines and i < len(configs) - 1: formatted_configs.append("") - + return "\n".join(formatted_configs) @@ -618,9 +619,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class # Add inputs section output += format_input_params(inputs + intermediates_inputs, indent_level=2) - + # Add outputs section output += "\n\n" output += format_output_params(outputs, indent_level=2) - return output \ No newline at end of file + return output diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 5f5e1c6c782d..4855a9bcfcd1 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -1,16 +1,19 @@ -from ..configuration_utils import ConfigMixin -from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks -from .modular_pipeline_utils import InputParam, OutputParam -from ..image_processor import PipelineImageInput -from pathlib import Path import json +import logging import os +from pathlib import Path +from typing import List, Optional, Tuple, Union -from typing import Union, List, Optional, Tuple -import torch -import PIL import numpy as np -import logging +import PIL +import torch + +from ..configuration_utils import ConfigMixin +from ..image_processor import PipelineImageInput +from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks +from .modular_pipeline_utils import InputParam + + logger = logging.getLogger(__name__) # YiYi Notes: this is actually for SDXL, put it here for now @@ -189,8 +192,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): if group_key in name: return group_name return None - - + + class ModularNode(ConfigMixin): config_name = "node_config.json" @@ -214,15 +217,15 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): self.name_mapping = {} input_params = {} - # pass or create a default param dict for each input + # pass or create a default param dict for each input # e.g. for prompt, # prompt = { # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers - # "label": "Prompt", - # "type": "string", - # "default": "a bear sitting in a chair drinking a milkshake", - # "display": "textarea"} - # if type is not specified, it'll be a "custom" param of its own type + # "label": "Prompt", + # "type": "string", + # "default": "a bear sitting in a chair drinking a milkshake", + # "display": "textarea"} + # if type is not specified, it'll be a "custom" param of its own type # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} @@ -236,10 +239,10 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): if mellon_name != inp.name: self.name_mapping[inp.name] = mellon_name continue - - if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): + + if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): continue - + if inp.name in DEFAULT_PARAM_MAPS: # first check if it's in the default param map, if so, directly use that param = DEFAULT_PARAM_MAPS[inp.name].copy() @@ -248,7 +251,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): if inp.name not in self.name_mapping: self.name_mapping[inp.name] = param else: - # if not, check if it's in the SDXL input schema, if so, + # if not, check if it's in the SDXL input schema, if so, # 1. use the type hint to determine the type # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} if inp.type_hint is not None: @@ -285,7 +288,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): break if to_exclude: continue - + if get_group_name(comp.name): param = get_group_name(comp.name) if comp.name not in self.name_mapping: @@ -303,7 +306,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): outputs = self.blocks.blocks[last_block_name].intermediates_outputs else: outputs = self.blocks.intermediates_outputs - + for out in outputs: param = kwargs.pop(out.name, None) if param: @@ -326,10 +329,10 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): param = out.name # add the param dict to the outputs dict output_params[out.name] = param - + if len(kwargs) > 0: logger.warning(f"Unused kwargs: {kwargs}") - + register_dict = { "category": category, "label": label, @@ -339,7 +342,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): "name_mapping": self.name_mapping, } self.register_to_config(**register_dict) - + def setup(self, components, collection=None): self.blocks.setup_loader(component_manager=components, collection=collection) self._components_manager = components @@ -347,7 +350,7 @@ def setup(self, components, collection=None): @property def mellon_config(self): return self._convert_to_mellon_config() - + def _convert_to_mellon_config(self): node = {} @@ -368,13 +371,13 @@ def _convert_to_mellon_config(self): } else: param = inp_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") - + for comp_name, comp_param in self.config.component_params.items(): if comp_name in self.name_mapping: mellon_name = self.name_mapping[comp_name] @@ -388,13 +391,13 @@ def _convert_to_mellon_config(self): } else: param = comp_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") - + for out_name, out_param in self.config.output_params.items(): if out_name in self.name_mapping: mellon_name = self.name_mapping[out_name] @@ -408,7 +411,7 @@ def _convert_to_mellon_config(self): } else: param = out_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: @@ -427,22 +430,22 @@ def save_mellon_config(self, file_path): Path: Path to the saved config file """ file_path = Path(file_path) - + # Create directory if it doesn't exist os.makedirs(file_path.parent, exist_ok=True) - + # Create a combined dictionary with module definition and name mapping config = { "module": self.mellon_config, "name_mapping": self.name_mapping } - + # Save the config to file with open(file_path, 'w', encoding='utf-8') as f: json.dump(config, f, indent=2) - + logger.info(f"Mellon config and name mapping saved to {file_path}") - + return file_path @classmethod @@ -457,16 +460,16 @@ def load_mellon_config(cls, file_path): dict: The loaded combined configuration containing 'module' and 'name_mapping' """ file_path = Path(file_path) - + if not file_path.exists(): raise FileNotFoundError(f"Config file not found: {file_path}") - + with open(file_path, 'r', encoding='utf-8') as f: config = json.load(f) - + logger.info(f"Mellon config loaded from {file_path}") - - + + return config def process_inputs(self, **kwargs): @@ -483,7 +486,7 @@ def process_inputs(self, **kwargs): if comp: params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) - + params_run = {} for inp_name, inp_param in self.config.input_params.items(): logger.debug(f"input: {inp_name}") @@ -495,14 +498,14 @@ def process_inputs(self, **kwargs): inp = kwargs.pop(mellon_inp_name) if inp is not None: params_run[inp_name] = inp - + return_output_names = list(self.config.output_params.keys()) return params_components, params_run, return_output_names def execute(self, **kwargs): params_components, params_run, return_output_names = self.process_inputs(**kwargs) - + self.blocks.loader.update(**params_components) output = self.blocks.run(**params_run, output=return_output_names) return output diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 1fbc141ac3de..2fe15bbbee4a 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -34,11 +34,24 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modular_pipeline_presets import StableDiffusionXLAutoPipeline - from .modular_loader import StableDiffusionXLModularLoader - from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep from .decoders import StableDiffusionXLAutoDecodeStep - from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS + from .encoders import ( + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLTextEncoderStep, + ) + from .modular_block_mappings import ( + AUTO_BLOCKS, + CONTROLNET_BLOCKS, + CONTROLNET_UNION_BLOCKS, + IMAGE2IMAGE_BLOCKS, + INPAINT_BLOCKS, + IP_ADAPTER_BLOCKS, + SDXL_SUPPORTED_BLOCKS, + TEXT2IMAGE_BLOCKS, + ) + from .modular_loader import StableDiffusionXLModularLoader + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index f6ff33967512..2032a57dcfcc 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -13,32 +13,27 @@ # limitations under the License. import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import Any, List, Optional, Tuple, Union import PIL import torch -from collections import OrderedDict -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...utils import logging -from ...utils.torch_utils import randn_tensor, unwrap_module - -from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel from ...schedulers import EulerDiscreteScheduler -from ...configuration_utils import FrozenDict - -from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ( AutoPipelineBlocks, - ModularLoader, PipelineBlock, PipelineState, SequentialPipelineBlocks, ) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_loader import StableDiffusionXLModularLoader + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -237,7 +232,7 @@ def intermediates_inputs(self) -> List[str]: InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), ] - + @property def intermediates_outputs(self) -> List[str]: return [ @@ -250,7 +245,7 @@ def intermediates_outputs(self) -> List[str]: OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), ] - + def check_inputs(self, components, block_state): if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: @@ -270,13 +265,13 @@ def check_inputs(self, components, block_state): raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) - + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): raise ValueError("`negative_ip_adapter_embeds` must be a list") - + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: @@ -298,19 +293,19 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # duplicate text embeddings for each generation per prompt, using mps friendly method block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - + if block_state.negative_prompt_embeds is not None: _, seq_len, _ = block_state.negative_prompt_embeds.shape block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - + if block_state.negative_pooled_prompt_embeds is not None: block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - + if block_state.ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) @@ -318,7 +313,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt if block_state.negative_ip_adapter_embeds is not None: for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - + self.add_block_state(state, block_state) return components, state @@ -356,14 +351,14 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), ] @property def intermediates_outputs(self) -> List[str]: return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] @@ -455,7 +450,7 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", EulerDiscreteScheduler), ] - + @property def description(self) -> str: return ( @@ -473,7 +468,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] @@ -524,7 +519,7 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), InputParam( - "strength", + "strength", default=0.9999, description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " @@ -540,46 +535,46 @@ def intermediates_inputs(self) -> List[str]: return [ InputParam("generator"), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, + "latent_timestep", + required=True, + type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), + ), InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, + "image_latents", + required=True, + type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "mask", - required=True, - type_hint=torch.Tensor, + "mask", + required=True, + type_hint=torch.Tensor, description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "masked_image_latents", - type_hint=torch.Tensor, + "masked_image_latents", + type_hint=torch.Tensor, description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." ), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] @@ -587,13 +582,13 @@ def intermediates_outputs(self) -> List[str]: # YiYi TODO: update the _encode_vae_image so that we can use #Coped from @staticmethod def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -619,7 +614,7 @@ def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generato else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents + return image_latents # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument def prepare_latents_inpaint( @@ -737,15 +732,15 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype block_state.device = components._execution_device - + block_state.is_strength_max = block_state.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents @@ -822,9 +817,9 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] @property @@ -886,14 +881,14 @@ def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @@ -902,8 +897,8 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam( - "latents", - type_hint=torch.Tensor, + "latents", + type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ) ] @@ -980,7 +975,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_configs(self) -> List[ConfigSpec]: return [ConfigSpec("requires_aesthetics_score", False),] @@ -1008,15 +1003,15 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -1183,29 +1178,29 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), + ), InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @@ -1344,26 +1339,26 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], + "crops_coords", + type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), ] @@ -1395,12 +1390,12 @@ def prepare_control_image( device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - + image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size @@ -1416,9 +1411,9 @@ def prepare_control_image( @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - + block_state = self.get_block_state(state) - + # (1) prepare controlnet inputs block_state.device = components._execution_device block_state.height, block_state.width = block_state.latents.shape[-2:] @@ -1446,14 +1441,14 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) # (1.3) - # global_pool_conditions + # global_pool_conditions block_state.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) # (1.4) - # guess_mode + # guess_mode block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.5) @@ -1501,12 +1496,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - + block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - + self.add_block_state(state, block_state) return components, state @@ -1542,32 +1537,32 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "dtype", - required=True, - type_hint=torch.dtype, + "dtype", + required=True, + type_hint=torch.dtype, description="The dtype of model tensor inputs. Can be generated in input step." - ), + ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." ), InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], + "crops_coords", + type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), ] @@ -1599,12 +1594,12 @@ def prepare_control_image( device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - + image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size @@ -1618,7 +1613,7 @@ def prepare_control_image( @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - + block_state = self.get_block_state(state) controlnet = unwrap_module(components.controlnet) @@ -1651,7 +1646,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt if len(block_state.control_image) != len(block_state.control_mode): raise ValueError("Expected len(control_image) == len(control_type)") - # control_type + # control_type block_state.num_control_type = controlnet.config.num_control_type block_state.control_type = [0 for _ in range(block_state.num_control_type)] for control_idx in block_state.control_mode: @@ -1676,7 +1671,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt crops_coords=block_state.crops_coords, ) block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - + # controlnet_keep block_state.controlnet_keep = [] for i in range(len(block_state.timesteps)): @@ -1687,7 +1682,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.control_type_idx = block_state.control_mode block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - + self.add_block_state(state, block_state) return components, state @@ -1698,7 +1693,7 @@ class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] block_names = ["controlnet_union", "controlnet"] block_trigger_inputs = ["control_mode", "control_image"] - + @property def description(self): return "Controlnet Input step that prepare the controlnet input.\n" + \ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index ca848e20984f..3a4e141775f5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -12,29 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import Any, List, Tuple, Union +import numpy as np import PIL import torch -import numpy as np -from collections import OrderedDict -from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...utils import logging - -from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput -from ...configuration_utils import FrozenDict - -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline import ( AutoPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, ) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,15 +40,15 @@ class StableDiffusionXLDecodeStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -160,10 +156,10 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("image", required=True), - InputParam("mask_image", required=True), + InputParam("mask_image", required=True), InputParam("padding_mask_crop"), ] - + @property def intermediates_inputs(self) -> List[str]: return [ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 3a8bca74b5a0..564665110006 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -13,28 +13,25 @@ # limitations under the License. import inspect -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple import torch -from tqdm.auto import tqdm from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import ControlNetModel, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging -from ...utils.torch_utils import unwrap_module - -from ...guiders import ClassifierFreeGuidance -from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline import ( - PipelineBlock, - PipelineState, AutoPipelineBlocks, - LoopSequentialPipelineBlocks, BlockState, + LoopSequentialPipelineBlocks, + PipelineBlock, + PipelineState, ) -from dataclasses import asdict +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_loader import StableDiffusionXLModularLoader + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -61,9 +58,9 @@ def description(self) -> str: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), ] @@ -96,19 +93,19 @@ def description(self) -> str: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), ] @@ -133,7 +130,7 @@ def check_inputs(components, block_state): f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `components.unet` or your `mask_image` or `image` input." ) - + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): @@ -155,9 +152,9 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ] @@ -178,18 +175,18 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." ), InputParam( - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", description=( "All conditional model inputs that need to be prepared with guider. " "It should contain prompt_embeds/negative_prompt_embeds, " @@ -202,10 +199,10 @@ def intermediates_inputs(self) -> List[str]: ] - + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) guider_input_fields ={ @@ -231,7 +228,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc cond_kwargs = guider_state_batch.as_dict() cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} prompt_embeds = cond_kwargs.pop("prompt_embeds") - + # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( @@ -259,9 +256,9 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), @@ -281,18 +278,18 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "controlnet_cond", + "controlnet_cond", required=True, type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "conditioning_scale", + "conditioning_scale", type_hint=float, description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "guess_mode", + "guess_mode", required=True, type_hint=bool, description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." @@ -304,18 +301,18 @@ def intermediates_inputs(self) -> List[str]: description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", description=( "All conditional model inputs that need to be prepared with guider. " "It should contain prompt_embeds/negative_prompt_embeds, " @@ -326,7 +323,7 @@ def intermediates_inputs(self) -> List[str]: ) ), InputParam( - kwargs_type="controlnet_kwargs", + kwargs_type="controlnet_kwargs", description=( "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" @@ -369,14 +366,14 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] - + # default controlnet output/unet input for guess mode + conditional path block_state.down_block_res_samples_zeros = None block_state.mid_block_res_sample_zeros = None - + # guided denoiser step components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - + # Prepare mini‐batches according to guidance method and `guider_input_fields` # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. # e.g. for CFG, we prepare two batches: one for uncond, one for cond @@ -387,7 +384,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # run the denoiser for each guidance batch for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) - + # Prepare additional conditionings added_cond_kwargs = { "text_embeds": guider_state_batch.text_embeds, @@ -395,7 +392,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc } if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds - + # Prepare controlnet additional conditionings controlnet_added_cond_kwargs = { "text_embeds": guider_state_batch.text_embeds, @@ -418,13 +415,13 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return_dict=False, **extra_controlnet_kwargs, ) - + # assign it to block_state so it will be available for the uncond guidance batch if block_state.down_block_res_samples_zeros is None: block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] if block_state.mid_block_res_sample_zeros is None: block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) - + # Predict the noise # store the noise_pred in guider_state_batch so we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( @@ -439,7 +436,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return_dict=False, )[0] components.guider.cleanup_models(components.unet) - + # Perform guidance block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) @@ -475,7 +472,7 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - + #YiYi TODO: move this out of here @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): @@ -499,7 +496,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -534,24 +531,24 @@ def intermediates_inputs(self) -> List[str]: return [ InputParam("generator"), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." ), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), ] @@ -559,7 +556,7 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - + @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): @@ -570,7 +567,7 @@ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): extra_kwargs[key] = value return extra_kwargs - + def check_inputs(self, components, block_state): if components.num_channels_unet == 4: if block_state.image_latents is None: @@ -582,9 +579,9 @@ def check_inputs(self, components, block_state): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - + self.check_inputs(components, block_state) - + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) @@ -592,12 +589,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 block_state.latents = block_state.latents.to(block_state.latents_dtype) - + # adjust latent for inpainting if components.num_channels_unet == 4: block_state.init_latents_proper = block_state.image_latents @@ -629,32 +626,32 @@ def description(self) -> str: def loop_expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] - + @property def loop_intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), ] - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -782,619 +779,4 @@ def description(self) -> str: "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." - ) - - - - - - - -# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. -# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - - -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state \ No newline at end of file + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index ca4efe2c4a7f..a563ffbbbe86 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -12,44 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple -import PIL import torch -from collections import OrderedDict - -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel -from ...configuration_utils import FrozenDict - from transformers import ( - CLIPTextModel, CLIPImageProcessor, + CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) -from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance - +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ..modular_pipeline import AutoPipelineBlocks, PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,7 +60,7 @@ def retrieve_latents( class StableDiffusionXLIPAdapterStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -79,7 +68,7 @@ def description(self) -> str: " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -87,8 +76,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec( - "guider", - ClassifierFreeGuidance, + "guider", + ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -97,8 +86,8 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam( - "ip_adapter_image", - PipelineImageInput, + "ip_adapter_image", + PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" ) @@ -111,7 +100,7 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") ] - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components @staticmethod def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): @@ -137,7 +126,7 @@ def encode_image(components, image, device, num_images_per_prompt, output_hidden uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds - + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds @@ -219,7 +208,7 @@ def description(self) -> str: return( "Text Encoder step that generate text_embeddings to guide the image generation" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -228,9 +217,9 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -546,7 +535,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -558,9 +547,9 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -576,7 +565,7 @@ def inputs(self) -> List[InputParam]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property @@ -586,13 +575,13 @@ def intermediates_outputs(self) -> List[OutputParam]: # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -618,8 +607,8 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents - + return image_latents + @torch.no_grad() @@ -628,7 +617,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} block_state.device = components._execution_device block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) @@ -651,23 +640,23 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ComponentSpec( - "mask_processor", - VaeImageProcessor, + "mask_processor", + VaeImageProcessor, config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), default_creation_method="from_config"), ] - + @property def description(self) -> str: @@ -694,21 +683,21 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -734,7 +723,7 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents + return image_latents # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance @@ -784,8 +773,8 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: @@ -801,7 +790,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt else: block_state.crops_coords = None block_state.resize_mode = "default" - + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) block_state.image = block_state.image.to(dtype=torch.float32) @@ -834,7 +823,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) # Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] block_names = ["inpaint", "img2img"] block_trigger_inputs = ["mask_image", "image"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py index 4ffd685df044..9440d72319f3 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py @@ -13,44 +13,40 @@ # limitations under the License. from ..modular_pipeline_utils import InsertableOrderedDict +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import StableDiffusionXLAutoDecodeStep, StableDiffusionXLDecodeStep, StableDiffusionXLInpaintDecodeStep # Import all the necessary block classes from .denoise import ( StableDiffusionXLAutoDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseLoop, - StableDiffusionXLInpaintDenoiseLoop -) -from .before_denoise import ( - StableDiffusionXLAutoBeforeDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLImg2ImgSetTimestepsStep, - StableDiffusionXLImg2ImgPrepareLatentsStep, - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLInpaintPrepareLatentsStep, - StableDiffusionXLControlNetInputStep, - StableDiffusionXLControlNetUnionInputStep + StableDiffusionXLInpaintDenoiseLoop, ) from .encoders import ( - StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLVaeEncoderStep, StableDiffusionXLInpaintVaeEncoderStep, - StableDiffusionXLIPAdapterStep -) -from .decoders import ( - StableDiffusionXLDecodeStep, - StableDiffusionXLInpaintDecodeStep, - StableDiffusionXLAutoDecodeStep + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, ) # YiYi notes: comment out for now, work on this later -# block mapping +# block mapping TEXT2IMAGE_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("input", StableDiffusionXLInputStep), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 4af942af64e6..0f567513c57d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union + +import numpy as np import PIL import torch -import numpy as np -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...image_processor import PipelineImageInput +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...pipelines.pipeline_utils import StableDiffusionMixin from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ...utils import logging - from ..modular_pipeline import ModularLoader from ..modular_pipeline_utils import InputParam, OutputParam + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 637c7ac306d7..981f4d7e033a 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union, Dict from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks - -from .denoise import StableDiffusionXLAutoDenoiseStep from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep from .decoders import StableDiffusionXLAutoDecodeStep -from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep +from .denoise import StableDiffusionXLAutoDenoiseStep +from .encoders import ( + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLTextEncoderStep, +) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 47aae7198450..8eb99038c172 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -15,12 +15,12 @@ """Utilities to dynamically load objects from the Hub.""" import importlib -import signal import inspect import json import os import re import shutil +import signal import sys import threading from pathlib import Path @@ -531,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module) \ No newline at end of file + return get_class_in_module(class_name, final_module) From 74b908b7e2c9ce65ac9f82a13a07aac0ae122bf4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 12:04:52 +0200 Subject: [PATCH 092/170] style --- src/diffusers/__init__.py | 34 +- src/diffusers/commands/custom_blocks.py | 14 +- src/diffusers/guiders/__init__.py | 10 +- .../guiders/adaptive_projected_guidance.py | 7 +- src/diffusers/guiders/auto_guidance.py | 15 +- .../guiders/classifier_free_guidance.py | 28 +- .../classifier_free_zero_star_guidance.py | 11 +- src/diffusers/guiders/guider_utils.py | 68 +- src/diffusers/guiders/skip_layer_guidance.py | 23 +- .../guiders/smoothed_energy_guidance.py | 37 +- .../tangential_classifier_free_guidance.py | 11 +- src/diffusers/hooks/layer_skip.py | 18 +- .../hooks/smoothed_energy_guidance_utils.py | 32 +- src/diffusers/loaders/ip_adapter.py | 13 +- .../modular_pipelines/components_manager.py | 161 +++-- .../modular_pipelines/modular_pipeline.py | 467 ++++++------ .../modular_pipeline_utils.py | 123 ++-- src/diffusers/modular_pipelines/node_utils.py | 302 +++++--- .../stable_diffusion_xl/__init__.py | 21 +- .../stable_diffusion_xl/before_denoise.py | 684 ++++++++++++------ .../stable_diffusion_xl/decoders.py | 89 ++- .../stable_diffusion_xl/denoise.py | 200 ++--- .../stable_diffusion_xl/encoders.py | 205 ++++-- .../modular_block_mappings.py | 115 +-- .../stable_diffusion_xl/modular_loader.py | 310 ++++++-- .../modular_pipeline_presets.py | 29 +- .../pipelines/pipeline_loading_utils.py | 2 + src/diffusers/pipelines/pipeline_utils.py | 11 +- 28 files changed, 1915 insertions(+), 1125 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 18d90be500b7..7bb8469c36d1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -38,8 +38,8 @@ "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], - "pipelines": [], "modular_pipelines": [], + "pipelines": [], "quantizers.quantization_config": [], "schedulers": [], "utils": [ @@ -147,8 +147,8 @@ [ "FasterCacheConfig", "HookRegistry", - "PyramidAttentionBroadcastConfig", "LayerSkipConfig", + "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_layer_skip", @@ -235,6 +235,15 @@ "WanVACETransformer3DModel", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ComponentsManager", + "ComponentSpec", + "ModularLoader", + "ModularPipeline", + "ModularPipelineBlocks", + ] + ) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -266,15 +275,6 @@ "StableDiffusionMixin", ] ) - _import_structure["modular_pipelines"].extend( - [ - "ModularLoader", - "ModularPipeline", - "ModularPipelineBlocks", - "ComponentSpec", - "ComponentsManager", - ] - ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ @@ -356,6 +356,12 @@ ] else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoPipeline", + "StableDiffusionXLModularLoader", + ] + ) _import_structure["pipelines"].extend( [ "AllegroPipeline", @@ -565,12 +571,6 @@ "WuerstchenPriorPipeline", ] ) - _import_structure["modular_pipelines"].extend( - [ - "StableDiffusionXLAutoPipeline", - "StableDiffusionXLModularLoader", - ] - ) try: diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index f532e8b775fd..07fca44678ba 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -30,6 +30,7 @@ EXPECTED_PARENT_CLASSES = ["PipelineBlock"] CONFIG = "config.json" + def conversion_command_factory(args: Namespace): return CustomBlocksCommand(args.block_module_name, args.block_class_name) @@ -45,7 +46,10 @@ def register_subcommand(parser: ArgumentParser): help="Module filename in which the custom block will be implemented.", ) conversion_parser.add_argument( - "--block_class_name", type=str, default=None, help="Name of the custom block. If provided None, we will try to infer it." + "--block_class_name", + type=str, + default=None, + help="Name of the custom block. If provided None, we will try to infer it.", ) conversion_parser.set_defaults(func=conversion_command_factory) @@ -71,7 +75,7 @@ def run(self): f"Found classes: {classes_found} will be using {classes_found[0]}. " "If this needs to be changed, re-run the command specifying `block_class_name`." ) - child_class, parent_class = out[0][0], out[0][1] + child_class, parent_class = out[0][0], out[0][1] # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory. # the user is responsible for running it, so I guess that is safe? @@ -107,10 +111,7 @@ def _get_class_names(self, file_path): continue # extract all base names for this class - base_names = [ - bname for b in node.bases - if (bname := self._get_base_name(b)) is not None - ] + base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None] # for each allowed base that appears in the class's bases, emit a tuple for allowed in EXPECTED_PARENT_CLASSES: @@ -131,4 +132,3 @@ def _create_automap(self, parent_class, child_class): module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] auto_map = {f"{parent_class}": f"{module}.{child_class}"} return {"auto_map": auto_map} - diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 3c1ee293382d..0c5198a17b20 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -26,4 +26,12 @@ from .smoothed_energy_guidance import SmoothedEnergyGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] + GuiderType = Union[ + AdaptiveProjectedGuidance, + AutoGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + SkipLayerGuidance, + SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, + ] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index f1a6096c4d6a..32c61823e908 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -27,7 +27,7 @@ class AdaptiveProjectedGuidance(BaseGuidance): """ Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -74,8 +74,9 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 83120c20ceca..9891de5e4d6b 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -29,7 +29,7 @@ class AutoGuidance(BaseGuidance): """ AutoGuidance: https://huggingface.co/papers/2406.02507 - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -86,7 +86,9 @@ def __init__( ) if auto_guidance_layers is not None and auto_guidance_config is not None: raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") - if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None): + if (dropout is None and auto_guidance_layers is not None) or ( + dropout is not None and auto_guidance_layers is None + ): raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") if auto_guidance_layers is not None: @@ -96,7 +98,9 @@ def __init__( raise ValueError( f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." ) - auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers] + auto_guidance_config = [ + LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers + ] if isinstance(auto_guidance_config, LayerSkipConfig): auto_guidance_config = [auto_guidance_config] @@ -121,8 +125,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index faeba0971157..f914deb56ca3 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -27,25 +27,25 @@ class ClassifierFreeGuidance(BaseGuidance): """ Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 - + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during - inference. This allows the model to tradeoff between generation quality and sample diversity. - The original paper proposes scaling and shifting the conditional distribution based on the difference between - conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] - + inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper + proposes scaling and shifting the conditional distribution based on the difference between conditional and + unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] - + The intution behind the original formulation can be thought of as moving the conditional distribution estimates further away from the unconditional distribution estimates, while the diffusers-native implementation can be thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) - + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -68,7 +68,12 @@ class ClassifierFreeGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] def __init__( - self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0 + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, ): super().__init__(start, stop) @@ -76,8 +81,9 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index b4dee9295ab6..1c70b45a5ed7 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -27,14 +27,14 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): """ Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 - + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the quality of generated images. - + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -74,8 +74,9 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 87109eb048ed..12731ed43530 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -43,13 +43,9 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._enabled = True if not (0.0 <= start < 1.0): - raise ValueError( - f"Expected `start` to be between 0.0 and 1.0, but got {start}." - ) + raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.") if not (start <= stop <= 1.0): - raise ValueError( - f"Expected `stop` to be between {start} and 1.0, but got {stop}." - ) + raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.") if self._input_predictions is None or not isinstance(self._input_predictions, list): raise ValueError( @@ -70,23 +66,21 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: """ - Set the input fields for the guidance technique. The input fields are used to specify the names of the - returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is - obtained from the values of the provided keyword arguments to this method. + Set the input fields for the guidance technique. The input fields are used to specify the names of the returned + attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from + the values of the provided keyword arguments to this method. Args: **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): - A dictionary where the keys are the names of the fields that will be used to store the data once - it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, - which is used to look up the required data provided for preparation. + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. - If a string is provided, it will be used as the conditional data (or unconditional if used with - a guidance method that requires it). If a tuple of length 2 is provided, the first element must - be the conditional data identifier and the second element must be the unconditional data identifier - or None. + If a string is provided, it will be used as the conditional data (or unconditional if used with a + guidance method that requires it). If a tuple of length 2 is provided, the first element must be the + conditional data identifier and the second element must be the unconditional data identifier or None. Example: - ``` data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": } @@ -98,7 +92,9 @@ def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> """ for key, value in kwargs.items(): is_string = isinstance(value, str) - is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + is_tuple_of_str_with_len_2 = ( + isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + ) if not (is_string or is_tuple_of_str_with_len_2): raise ValueError( f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." @@ -114,8 +110,8 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: def cleanup_models(self, denoiser: torch.nn.Module) -> None: """ - Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in - subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden + in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful modifications made during `prepare_models`. """ pass @@ -149,32 +145,39 @@ def num_conditions(self) -> int: raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") @classmethod - def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": + def _prepare_batch( + cls, + input_fields: Dict[str, Union[str, Tuple[str, str]]], + data: "BlockState", + tuple_index: int, + identifier: str, + ) -> "BlockState": """ - Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of - the `BaseGuidance` class. It prepares the batch based on the provided tuple index. + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the + `BaseGuidance` class. It prepares the batch based on the provided tuple index. Args: input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): - A dictionary where the keys are the names of the fields that will be used to store the data once - it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, - which is used to look up the required data provided for preparation. - If a string is provided, it will be used as the conditional data (or unconditional if used with - a guidance method that requires it). If a tuple of length 2 is provided, the first element must - be the conditional data identifier and the second element must be the unconditional data identifier - or None. + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. If a string is provided, it will be used as the + conditional data (or unconditional if used with a guidance method that requires it). If a tuple of + length 2 is provided, the first element must be the conditional data identifier and the second element + must be the unconditional data identifier or None. data (`BlockState`): The input data to be prepared. tuple_index (`int`): The index to use when accessing input fields that are tuples. - + Returns: `BlockState`: The prepared batch of data. """ from ..modular_pipelines.modular_pipeline import BlockState if input_fields is None: - raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") + raise ValueError( + "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs." + ) data_batch = {} for key, value in input_fields.items(): try: @@ -196,6 +199,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: noise_cfg (`torch.Tensor`): The predicted noise tensor for the guided diffusion process. diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index ffe00ea7db33..f0e7f035420b 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -29,28 +29,28 @@ class SkipLayerGuidance(BaseGuidance): """ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 - + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 - + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional batch of data, apart from the conditional and unconditional batches already used in CFG ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions based on the difference between conditional without skipping and conditional with skipping predictions. - + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse version of the model for the conditional prediction). - + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving generation quality in video diffusion models. - + Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) - + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -157,8 +157,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields @@ -167,7 +168,9 @@ def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Un input_predictions = ["pred_cond"] elif self.num_conditions == 2: tuple_indices = [0, 1] - input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) else: tuple_indices = [0, 1, 0] input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index ab21b6d9526d..a96a5e3e04da 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -30,12 +30,12 @@ class SmoothedEnergyGuidance(BaseGuidance): """ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 - SEG is only supported as an experimental prototype feature for now, so the implementation may be modified - in the future without warning or guarantee of reproducibility. This implementation assumes: + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the + future without warning or guarantee of reproducibility. This implementation assumes: - Generated images are square (height == width) - - The model does not combine different modalities together (e.g., text and image latent streams are - not combined together such as Flux) - + - The model does not combine different modalities together (e.g., text and image latent streams are not combined + together such as Flux) + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -54,12 +54,12 @@ class SmoothedEnergyGuidance(BaseGuidance): seg_guidance_stop (`float`, defaults to `1.0`): The fraction of the total number of denoising steps after which smoothed energy guidance stops. seg_guidance_layers (`int` or `List[int]`, *optional*): - The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not - provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion - 3.5 Medium. + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If + not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable + Diffusion 3.5 Medium. seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): - The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of - `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or + a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. guidance_rescale (`float`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are @@ -103,13 +103,9 @@ def __init__( self.use_original_formulation = use_original_formulation if not (0.0 <= seg_guidance_start < 1.0): - raise ValueError( - f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." - ) + raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.") if not (seg_guidance_start <= seg_guidance_stop <= 1.0): - raise ValueError( - f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." - ) + raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.") if seg_guidance_layers is None and seg_guidance_config is None: raise ValueError( @@ -150,8 +146,9 @@ def cleanup_models(self, denoiser: torch.nn.Module): for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields @@ -160,7 +157,9 @@ def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Un input_predictions = ["pred_cond"] elif self.num_conditions == 2: tuple_indices = [0, 1] - input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + ) else: tuple_indices = [0, 1, 0] input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index fdcdaf8dcb3a..17244f3788d9 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -27,7 +27,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance): """ Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 - + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text @@ -63,8 +63,9 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: if input_fields is None: input_fields = self._input_fields @@ -118,7 +119,9 @@ def _is_tcfg_enabled(self) -> bool: return is_within_range and not is_close -def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: +def normalized_guidance( + pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False +) -> torch.Tensor: cond_dtype = pred_cond.dtype preds = torch.stack([pred_cond, pred_uncond], dim=1).float() preds = preds.flatten(2) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 6b847271c97b..a581ce7712df 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -41,15 +41,15 @@ class LayerSkipConfig: r""" Configuration for skipping internal transformer blocks when executing a transformer model. - + Args: indices (`List[int]`): The indices of the layer to skip. This is typically the first layer in the transformer block. fqn (`str`, defaults to `"auto"`): The fully qualified name identifying the stack of transformer blocks. Typically, this is `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. - For automatic detection, set this to `"auto"`. - "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must + provide the correct fqn. skip_attention (`bool`, defaults to `True`): Whether to skip attention blocks. skip_ff (`bool`, defaults to `True`): @@ -149,20 +149,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = torch.nn.functional.dropout(output, p=self.dropout) return output + def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" Apply layer skipping to internal layers of a transformer. - + Args: module (`torch.nn.Module`): The transformer model to which the layer skip hook should be applied. config (`LayerSkipConfig`): The configuration for the layer skip hook. - + Example: - + ```python >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") >>> apply_layer_skip_hook(transformer, config) @@ -177,7 +179,9 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam if config.skip_attention and config.skip_attention_scores: raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: - raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index 353ce7289444..cbbb3eec15b0 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ..utils import get_logger -from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn from .hooks import HookRegistry, ModelHook @@ -33,18 +33,18 @@ class SmoothedEnergyGuidanceConfig: r""" Configuration for skipping internal transformer blocks when executing a transformer model. - + Args: indices (`List[int]`): The indices of the layer to skip. This is typically the first layer in the transformer block. fqn (`str`, defaults to `"auto"`): The fully qualified name identifying the stack of transformer blocks. Typically, this is `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. - For automatic detection, set this to `"auto"`. - "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must + provide the correct fqn. _query_proj_identifiers (`List[str]`, defaults to `None`): - The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. - If `None`, `to_q` is used by default. + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If + `None`, `to_q` is used by default. """ indices: List[int] @@ -65,7 +65,9 @@ def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.T return smoothed_output -def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: +def _apply_smoothed_energy_guidance_hook( + module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None +) -> None: name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK if config.fqn == "auto": @@ -114,14 +116,14 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth # Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: """ - This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian - blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, - this implementation also assumes that the visual tokens come from a square image/video. In practice, despite - these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results - for Smoothed Energy Guidance. - - SEG is only supported as an experimental prototype feature for now, so the implementation may be modified - in the future without warning or guarantee of reproducibility. + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur. + However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this + implementation also assumes that the visual tokens come from a square image/video. In practice, despite these + assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for + Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the + future without warning or guarantee of reproducibility. """ assert query.ndim == 3 diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 58043e6b2322..e05d53687a24 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -353,6 +353,7 @@ def unload_ip_adapter(self): ) self.unet.set_attn_processor(attn_procs) + class ModularIPAdapterMixin: """Mixin for handling IP Adapters.""" @@ -491,15 +492,6 @@ def load_ip_adapter( state_dicts.append(state_dict) - # create feature extractor if it has not been registered to the pipeline yet - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: - # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 - default_clip_size = 224 - clip_image_size = ( - self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size - ) - feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) - unet_name = getattr(self, "unet_name", "unet") unet = getattr(self, unet_name) unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) @@ -558,8 +550,7 @@ def set_ip_adapter_scale(self, scale): ): if len(scale_configs) != len(attn_processor.scale): raise ValueError( - f"Cannot assign {len(scale_configs)} scale_configs to " - f"{len(attn_processor.scale)} IP-Adapter." + f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter." ) elif len(scale_configs) == 1: scale_configs = scale_configs * len(attn_processor.scale) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 3f22fa7115be..8f5c04d8a94d 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -228,16 +228,14 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload - class ComponentsManager: def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added - self.collections = OrderedDict() # collection_name -> set of component_names + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ Lookup component_ids by name, collection, or load_id. @@ -276,7 +274,6 @@ def _id_to_name(component_id: str): return "_".join(component_id.split("_")[:-1]) def add(self, name, component, collection: Optional[str] = None): - component_id = f"{name}_{uuid.uuid4()}" # check for duplicated components @@ -284,17 +281,14 @@ def add(self, name, component, collection: Optional[str] = None): if comp == component: comp_name = self._id_to_name(comp_id) if comp_name == name: - logger.warning( - f"component '{name}' already exists as '{comp_id}'" - ) + logger.warning(f"component '{name}' already exists as '{comp_id}'") component_id = comp_id break else: logger.warning( f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" f"To remove a duplicate, call `components_manager.remove('')`." - ) - + ) # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": @@ -330,9 +324,7 @@ def add(self, name, component, collection: Optional[str] = None): return component_id - def remove(self, component_id: str = None): - if component_id not in self.components: logger.warning(f"Component '{component_id}' not found in ComponentsManager") return @@ -351,15 +343,21 @@ def remove(self, component_id: str = None): component.to("cpu") del component import gc + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() - def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False): + def get( + self, + names: Union[str, List[str]] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + as_name_component_tuples: bool = False, + ): """ Select components by name with simple pattern matching. - + Args: names: Component name(s) or pattern(s) Patterns: @@ -376,10 +374,10 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N load_id: Optional load_id to filter by as_name_component_tuples: If True, returns a list of (name, component) tuples using base names instead of a dictionary with component IDs as keys - + Returns: - Dictionary mapping component IDs to components, - or list of (base_name, component) tuples if as_name_component_tuples=True + Dictionary mapping component IDs to components or list of (base_name, component) tuples if + as_name_component_tuples=True """ selected_ids = self._lookup_ids(collection=collection, load_id=load_id) @@ -387,10 +385,10 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N # Helper to extract base name from component_id def get_base_name(component_id): - parts = component_id.split('_') + parts = component_id.split("_") # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) + if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]: + return "_".join(parts[:-1]) return component_id if names is None: @@ -405,7 +403,7 @@ def get_base_name(component_id): def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. - + Args: component_id: The component ID to check pattern: The pattern to match against @@ -418,13 +416,13 @@ def matches_pattern(component_id, pattern, exact_match=False): return pattern == base_name # Prefix match (ends with *) - elif pattern.endswith('*'): + elif pattern.endswith("*"): prefix = pattern[:-1] return base_name.startswith(prefix) # Contains match (starts with *) - elif pattern.startswith('*'): - search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + elif pattern.startswith("*"): + search = pattern[1:-1] if pattern.endswith("*") else pattern[1:] return search in base_name # Exact match (no wildcards) @@ -433,18 +431,18 @@ def matches_pattern(component_id, pattern, exact_match=False): if isinstance(names, str): # Check if this is a "not" pattern - is_not_pattern = names.startswith('!') + is_not_pattern = names.startswith("!") if is_not_pattern: names = names[1:] # Remove the ! prefix # Handle OR patterns (containing |) - if '|' in names: - terms = names.split('|') + if "|" in names: + terms = names.split("|") matches = {} for comp_id, comp in components.items(): # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) # Check if any of the terms match this component should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) @@ -464,20 +462,24 @@ def matches_pattern(component_id, pattern, exact_match=False): elif any(names == base_name for base_name in base_names.values()): # Find all components with this base name matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp + for comp_id, comp in components.items() if (base_names[comp_id] == names) != is_not_pattern } if is_not_pattern: - logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") + logger.info( + f"Getting all components except those with base name '{names}': {list(matches.keys())}" + ) else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") # Prefix match (ends with *) - elif names.endswith('*'): + elif names.endswith("*"): prefix = names[:-1] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp + for comp_id, comp in components.items() if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: @@ -486,10 +488,11 @@ def matches_pattern(component_id, pattern, exact_match=False): logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") # Contains match (starts with *) - elif names.startswith('*'): - search = names[1:-1] if names.endswith('*') else names[1:] + elif names.startswith("*"): + search = names[1:-1] if names.endswith("*") else names[1:] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp + for comp_id, comp in components.items() if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: @@ -500,7 +503,8 @@ def matches_pattern(component_id, pattern, exact_match=False): # Substring match (no wildcards, but not an exact component name) elif any(names in base_name for base_name in base_names.values()): matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp + for comp_id, comp in components.items() if (names in base_names[comp_id]) != is_not_pattern } if is_not_pattern: @@ -533,7 +537,7 @@ def matches_pattern(component_id, pattern, exact_match=False): else: raise ValueError(f"Invalid type for names: {type(names)}") - def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): + def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) @@ -573,18 +577,19 @@ def disable_auto_cpu_offload(self): self._auto_offload_enabled = False # YiYi TODO: add quantization info - def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + def get_model_info( + self, component_id: str, fields: Optional[Union[str, List[str]]] = None + ) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. - + Args: component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. - + Returns: - Dictionary containing requested component metadata. - If fields is specified, returns only those fields. - If a single field is requested as string, returns just that field's value. + Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a + single field is requested as string, returns just that field's value. """ if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") @@ -595,7 +600,8 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str info = { "model_id": component_id, "added_time": self.added_time[component_id], - "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) + or None, } # Additional info for torch.nn.Module components @@ -606,13 +612,15 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str if has_hook and hasattr(component._hf_hook, "execution_device"): execution_device = component._hf_hook.execution_device - info.update({ - "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), - "adapters": None, # Default to None - "has_hook": has_hook, - "execution_device": execution_device, - }) + info.update( + { + "class_name": component.__class__.__name__, + "size_gb": get_memory_footprint(component) / (1024**3), + "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, + } + ) # Get adapters if applicable if hasattr(component, "peft_config"): @@ -649,10 +657,10 @@ def __repr__(self): def get_simple_name(name): # Extract the base name by splitting on underscore and taking first part # This assumes names are in format "name_uuid" - parts = name.split('_') + parts = name.split("_") # If we have at least 2 parts and the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) + if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]: + return "_".join(parts[:-1]) return name # Extract load_id if available @@ -664,10 +672,10 @@ def get_load_id(component): # Format device info compactly def format_device(component, info): if not info["has_hook"]: - return str(getattr(component, 'device', 'N/A')) + return str(getattr(component, "device", "N/A")) else: - device = str(getattr(component, 'device', 'N/A')) - exec_device = str(info['execution_device'] or 'N/A') + device = str(getattr(component, "device", "N/A")) + exec_device = str(info["execution_device"] or "N/A") return f"{device}({exec_device})" # Get all simple names to calculate width @@ -702,7 +710,7 @@ def format_device(component, info): "dtype": 15, "size": 10, "load_id": max_load_id_len, - "collection": max_collection_len + "collection": max_collection_len, } # Create the header lines @@ -791,7 +799,7 @@ def format_device(component, info): def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. - + Args: pretrained_model_name_or_path (str): The path or identifier of the pretrained model prefix (str, optional): Prefix to add to all component names loaded from this model. @@ -802,6 +810,7 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = # YiYi TODO: extend AutoModel to support non-diffusers models if subfolder: from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: @@ -814,9 +823,9 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = ) else: from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): - if component is None: continue @@ -832,18 +841,24 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) - def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + def get_one( + self, + component_id: Optional[str] = None, + name: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + ) -> Any: """ Get a single component by name. Raises an error if multiple components match or none are found. - + Args: name: Component name or pattern collection: Optional collection to filter by load_id: Optional load_id to filter by - + Returns: A single component - + Raises: ValueError: If no components match or multiple components match """ @@ -866,20 +881,18 @@ def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None return next(iter(results.values())) + def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. - - For a dictionary with dot-separated keys like: - { + + For a dictionary with dot-separated keys like: { 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], } - - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: - { - 'down_blocks': [0.6], - 'up_blocks': [0.3] + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { + 'down_blocks': [0.6], 'up_blocks': [0.3] } """ # First group by values - convert lists to tuples to make them hashable @@ -898,7 +911,7 @@ def find_common_prefix(keys: List[str]) -> str: return keys[0] # Split all keys into parts - key_parts = [k.split('.') for k in keys] + key_parts = [k.split(".") for k in keys] # Find how many initial parts are common common_length = 0 @@ -912,7 +925,7 @@ def find_common_prefix(keys: List[str]) -> str: return "" # Return the common prefix - return '.'.join(key_parts[0][:common_length]) + return ".".join(key_parts[0][:common_length]) # Create summary by finding common prefixes for each value group summary = {} diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 0d7bec5a5cc2..24d6e4caec13 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -62,7 +62,6 @@ ) - @dataclass class PipelineState: """ @@ -77,7 +76,7 @@ class PipelineState: def add_input(self, key: str, value: Any, kwargs_type: str = None): """ Add an input to the pipeline state with optional metadata. - + Args: key (str): The key for the input value (Any): The input value @@ -93,7 +92,7 @@ def add_input(self, key: str, value: Any, kwargs_type: str = None): def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): """ Add an intermediate value to the pipeline state with optional metadata. - + Args: key (str): The key for the intermediate value value (Any): The intermediate value @@ -117,10 +116,10 @@ def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ Get all inputs with matching kwargs_type. - + Args: kwargs_type (str): The kwargs_type to filter by - + Returns: Dict[str, Any]: Dictionary of inputs with matching kwargs_type """ @@ -130,10 +129,10 @@ def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ Get all intermediates with matching kwargs_type. - + Args: kwargs_type (str): The kwargs_type to filter by - + Returns: Dict[str, Any]: Dictionary of intermediates with matching kwargs_type """ @@ -180,6 +179,7 @@ class BlockState: """ Container for block state data with attribute access and formatted representation. """ + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @@ -195,11 +195,11 @@ def __setitem__(self, key: str, value: Any): def as_dict(self): """ Convert BlockState to a dictionary. - + Returns: Dict[str, Any]: Dictionary containing all attributes of the BlockState """ - return {key: value for key, value in self.__dict__.items()} + return dict(self.__dict__.items()) def __repr__(self): def format_value(v): @@ -227,7 +227,12 @@ def format_value(v): for k, val in v.items(): if hasattr(val, "shape") and hasattr(val, "dtype"): formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" - elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + elif ( + isinstance(val, list) + and len(val) > 0 + and hasattr(val[0], "shape") + and hasattr(val[0], "dtype") + ): shapes = [t.shape for t in val] formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" else: @@ -243,7 +248,8 @@ def format_value(v): class ModularPipelineBlocks(ConfigMixin): """ - Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks + Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, + LoopSequentialPipelineBlocks """ config_name = "config.json" @@ -257,7 +263,6 @@ def _get_signature_keys(cls, obj): return expected_modules, optional_parameters - @classmethod def from_pretrained( cls, @@ -303,7 +308,7 @@ def from_pretrained( return block_cls(**block_kwargs) - def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): + def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): # TODO: factor out this logic. cls_name = self.__class__.__name__ @@ -317,7 +322,12 @@ def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): config = dict(self.config) self._internal_dict = FrozenDict(config) - def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + def init_pipeline( + self, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + ): """ create a ModularLoader, optionally accept modular_repo to load from hub. """ @@ -331,13 +341,17 @@ def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.Pa # Create the loader with the updated specs specs = component_specs + config_specs - loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection) + loader = loader_class( + specs=specs, + pretrained_model_name_or_path=pretrained_model_name_or_path, + component_manager=component_manager, + collection=collection, + ) modular_pipeline = ModularPipeline(blocks=self, loader=loader) return modular_pipeline class PipelineBlock(ModularPipelineBlocks): - model_name = None @property @@ -354,7 +368,6 @@ def expected_components(self) -> List[ComponentSpec]: def expected_configs(self) -> List[ConfigSpec]: return [] - @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" @@ -390,7 +403,6 @@ def _get_required_inputs(self): def required_inputs(self) -> List[str]: return self._get_required_inputs() - def _get_required_intermediates_inputs(self): input_names = [] for input_param in self.intermediates_inputs: @@ -404,7 +416,6 @@ def _get_required_intermediates_inputs(self): def required_intermediates_inputs(self) -> List[str]: return self._get_required_intermediates_inputs() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -413,14 +424,14 @@ def __repr__(self): base_class = self.__class__.__bases__[0].__name__ # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) @@ -437,20 +448,12 @@ def __repr__(self): inputs = "Inputs:\n " + inputs_str # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f"{components}\n" - f"{configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" + intermediates_str = format_intermediates_short( + self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs ) + intermediates = f"Intermediates:\n{intermediates_str}" + return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)" @property def doc(self): @@ -461,10 +464,9 @@ def doc(self): self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - # YiYi TODO: input and inteermediate inputs with same name? should warn? def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" @@ -544,13 +546,13 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current + default value is None and new default value is not None. Warns if multiple non-None default values exist for the + same input. Args: named_input_lists: List of tuples containing (block_name, input_param_list) pairs - + Returns: List[InputParam]: Combined list of unique InputParam objects """ @@ -565,9 +567,11 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li input_name = input_param.name if input_name in combined_dict: current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): warnings.warn( f"Multiple different default values found for input '{input_name}': " f"{current_param.default} (from block '{value_sources[input_name]}') and " @@ -582,14 +586,15 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li return list(combined_dict.values()) + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first + occurrence of each output name. Args: named_output_lists: List of tuples containing (block_name, output_param_list) pairs - + Returns: List[OutputParam]: Combined list of unique OutputParam objects """ @@ -597,7 +602,9 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> for block_name, outputs in named_output_lists: for output_param in outputs: - if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): + if (output_param.name not in combined_dict) or ( + combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None + ): combined_dict[output_param.name] = output_param return list(combined_dict.values()) @@ -623,15 +630,15 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." + ) default_blocks = [t for t in self.block_trigger_inputs if t is None] # can only have 1 or 0 default block, and has to put in the last # the order of blocksmatters here because the first block with matching trigger will be dispatched # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): + if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): raise ValueError( f"In {self.__class__.__name__}, exactly one None must be specified as the last element " "in block_trigger_inputs." @@ -668,7 +675,6 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - @property def required_inputs(self) -> List[str]: if None not in self.block_trigger_inputs: @@ -699,7 +705,6 @@ def required_intermediates_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: @@ -713,7 +718,6 @@ def inputs(self) -> List[Tuple[str, Any]]: input_param.required = False return combined_inputs - @property def intermediates_inputs(self) -> List[str]: named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] @@ -769,21 +773,22 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: def _get_trigger_inputs(self): """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values + Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique + block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) @@ -802,12 +807,9 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -819,14 +821,14 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) @@ -841,7 +843,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -857,10 +859,10 @@ def __repr__(self): blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" # Build the representation with conditional sections @@ -879,7 +881,6 @@ def __repr__(self): return result - @property def doc(self): return make_doc_string( @@ -889,7 +890,7 @@ def doc(self): self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) @@ -897,10 +898,10 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ + block_classes = [] block_names = [] - @property def description(self): return "" @@ -909,7 +910,6 @@ def description(self): def model_name(self): return next(iter(self.blocks.values())).model_name - @property def expected_components(self): expected_components = [] @@ -931,10 +931,10 @@ def expected_configs(self): @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - + Args: blocks_dict: Dictionary mapping block names to block classes or instances - + Returns: A new SequentialPipelineBlocks instance """ @@ -959,7 +959,6 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks - @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary @@ -1031,7 +1030,7 @@ def get_intermediates_inputs(self): def intermediates_outputs(self) -> List[str]: named_outputs = [] for name, block in self.blocks.items(): - inp_names = set([inp.name for inp in block.intermediates_inputs]) + inp_names = {inp.name for inp in block.intermediates_inputs} # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) # filter out them here so they do not end up as intermediates_outputs if name not in inp_names: @@ -1044,6 +1043,7 @@ def intermediates_outputs(self) -> List[str]: def outputs(self) -> List[str]: # return next(reversed(self.blocks.values())).intermediates_outputs return self.intermediates_outputs + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -1061,21 +1061,22 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: def _get_trigger_inputs(self): """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values + Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique + block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) @@ -1090,23 +1091,24 @@ def trigger_inputs(self): def _traverse_trigger_blocks(self, trigger_inputs): # Convert trigger_inputs to a set for easier manipulation active_triggers = set(trigger_inputs) + def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): + if not hasattr(block, "block_trigger_inputs"): + if hasattr(block, "blocks"): # sequential or LoopSequentialPipelineBlocks (keep traversing) for sub_block_name, sub_block in block.blocks.items(): blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} + blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} result_blocks.update(blocks_to_update) else: # PipelineBlock result_blocks[block_name] = block # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): + if hasattr(block, "outputs"): active_triggers.update(out.name for out in block.outputs) return result_blocks @@ -1114,28 +1116,25 @@ def fn_recursive_traverse(block, block_name, active_triggers): else: # Find first block_trigger_input that matches any value in our active_triggers this_block = None - matching_trigger = None for trigger_input in block.block_trigger_inputs: if trigger_input is not None and trigger_input in active_triggers: this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input break # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] - matching_trigger = None if this_block is not None: # sequential/auto (keep traversing) - if hasattr(this_block, 'blocks'): + if hasattr(this_block, "blocks"): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? - if hasattr(this_block, 'outputs'): + if hasattr(this_block, "outputs"): active_triggers.update(out.name for out in this_block.outputs) return result_blocks @@ -1150,7 +1149,6 @@ def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs if trigger_inputs is not None: - if not isinstance(trigger_inputs, (list, tuple, set)): trigger_inputs = [trigger_inputs] invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] @@ -1172,12 +1170,9 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -1189,14 +1184,14 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) @@ -1211,7 +1206,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -1227,10 +1222,10 @@ def __repr__(self): blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" # Build the representation with conditional sections @@ -1249,7 +1244,6 @@ def __repr__(self): return result - @property def doc(self): return make_doc_string( @@ -1259,13 +1253,15 @@ def doc(self): self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) -#YiYi TODO: __repr__ + +# YiYi TODO: __repr__ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ - A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in + sequence. """ model_name = None @@ -1300,7 +1296,6 @@ def loop_intermediates_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] - @property def loop_required_inputs(self) -> List[str]: input_names = [] @@ -1361,7 +1356,6 @@ def get_inputs(self): def inputs(self): return self.get_inputs() - # modified from SequentialPipelineBlocks to include loop_intermediates_inputs @property def intermediates_inputs(self): @@ -1372,7 +1366,6 @@ def intermediates_inputs(self): intermediates.append(loop_intermediate_input) return intermediates - # Copied from SequentialPipelineBlocks def get_intermediates_inputs(self): inputs = [] @@ -1394,7 +1387,6 @@ def get_intermediates_inputs(self): outputs.update(block_intermediates_outputs) return inputs - # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block @property def required_inputs(self) -> List[str]: @@ -1425,7 +1417,6 @@ def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs.append(input_param.name) return required_intermediates_inputs - # YiYi TODO: this need to be thought about more # modified from SequentialPipelineBlocks to include loop_intermediates_outputs @property @@ -1433,7 +1424,7 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) for output in self.loop_intermediates_outputs: - if output.name not in set([output.name for output in combined_outputs]): + if output.name not in {output.name for output in combined_outputs}: combined_outputs.append(output) return combined_outputs @@ -1443,7 +1434,6 @@ def intermediates_outputs(self) -> List[str]: def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs - def __init__(self): blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): @@ -1452,11 +1442,12 @@ def __init__(self): @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": - """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. - + """ + Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. + Args: blocks_dict: Dictionary mapping block names to block instances - + Returns: A new LoopSequentialPipelineBlocks instance """ @@ -1467,7 +1458,6 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelin return instance def loop_step(self, components, state: PipelineState, **kwargs): - for block_name, block in self.blocks.items(): try: components, state = block(components, state, **kwargs) @@ -1484,7 +1474,6 @@ def loop_step(self, components, state: PipelineState, **kwargs): def __call__(self, components, state: PipelineState) -> PipelineState: raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} @@ -1554,7 +1543,6 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(param_name, param, input_param.kwargs_type) - @property def doc(self): return make_doc_string( @@ -1564,30 +1552,28 @@ def doc(self): self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) # modified from SequentialPipelineBlocks, - #(does not need trigger_inputs related part so removed them, + # (does not need trigger_inputs related part so removed them, # do not need to support auto block for loop blocks) def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) @@ -1600,15 +1586,14 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): - # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" # Build the representation with conditional sections @@ -1657,33 +1642,35 @@ class ModularLoader(ConfigMixin, PushToHubMixin): Base class for all Modular pipelines loaders. """ + config_name = "modular_model_index.json" hf_device_map = None - def register_components(self, **kwargs): """ Register components with their corresponding specifications. - + This method is responsible for: 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) 2. Updates the modular_model_index.json configuration for serialization 4. Adds components to the component manager if one is attached - + This method is called when: - - Components are first initialized in __init__: - - from_pretrained components not loaded during __init__ so they are registered as None; + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; - non from_pretrained components are created during __init__ and registered as the object itself - - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or + loader.update(guider=guider_spec) - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) - + Args: **kwargs: Keyword arguments where keys are component names and values are component objects. E.g., register_components(unet=unet_model, text_encoder=encoder_model) - + Notes: - Components must be created from ComponentSpec (have _diffusers_load_id attribute) - - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None + - When registering None for a component, it updates the modular_model_index.json config but sets attribute + to None """ for name, module in kwargs.items(): # current component spec @@ -1701,7 +1688,7 @@ def register_components(self, **kwargs): if module is not None: # actual library and class name of the module - library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config # e.g. {"repo": "stabilityai/stable-diffusion-2-1", @@ -1731,7 +1718,9 @@ def register_components(self, **kwargs): current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + logger.info( + f"ModularLoader.register_components: {name} is already registered with same object, skipping" + ) continue # warn if unregister @@ -1741,10 +1730,12 @@ def register_components(self, **kwargs): f"(was {current_module.__class__.__name__})" ) # same type, new instance → replace but send debug log - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: + elif ( + current_module is not None + and module is not None + and isinstance(module, current_module.__class__) + and current_module != module + ): logger.debug( f"ModularLoader.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" @@ -1758,21 +1749,22 @@ def register_components(self, **kwargs): if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + def __init__( + self, + specs: List[Union[ComponentSpec, ConfigSpec]], + pretrained_model_name_or_path: Optional[str] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): """ Initialize the loader with a list of component specs and config specs. """ self._component_manager = component_manager self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } + self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} + self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} # update component_specs and config_specs from modular_repo if pretrained_model_name_or_path is not None: @@ -1780,7 +1772,12 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_mod for name, value in config_dict.items(): # only update component_spec for from_pretrained components - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + if ( + name in self._component_specs + and self._component_specs[name].default_creation_method == "from_pretrained" + and isinstance(value, (tuple, list)) + and len(value) == 3 + ): library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) self._component_specs[name] = component_spec @@ -1802,7 +1799,6 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_mod default_configs[name] = config_spec.default self.register_to_config(**default_configs) - @property def device(self) -> torch.device: r""" @@ -1840,7 +1836,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - @property def dtype(self) -> torch.dtype: r""" @@ -1855,62 +1850,60 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property def components(self) -> Dict[str, Any]: # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } + return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} def update(self, **kwargs): """ Update components and configuration values after the loader has been instantiated. - + This method allows you to: 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) 2. Update configuration values (e.g., changing requires_safety_checker flag) - + Args: **kwargs: Component objects or configuration values to update: - - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) - - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it - + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, + text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., + `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call + create() method to create it + Raises: ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) - + Examples: ```python # Update multiple components at once - loader.update( - unet=new_unet_model, - text_encoder=new_text_encoder - ) - + loader.update(unet=new_unet_model, text_encoder=new_text_encoder) + # Update configuration values - loader.update( - requires_safety_checker=False, - guidance_rescale=0.7 - ) - + loader.update(requires_safety_checker=False, guidance_rescale=0.7) + # Update both components and configs together - loader.update( - unet=new_unet_model, - requires_safety_checker=False - ) + loader.update(unet=new_unet_model, requires_safety_checker=False) # update with ComponentSpec objects loader.update( - guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + guider=ComponentSpec( + name="guider", + type_hint=ClassifierFreeGuidance, + config={"guidance_scale": 5.0}, + default_creation_method="from_config", + ) ) ``` """ # extract component_specs_updates & config_specs_updates from `specs` - passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} + passed_component_specs = { + k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec) + } + passed_components = { + k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec) + } passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): @@ -1926,8 +1919,12 @@ def update(self, **kwargs): ) current_component_spec = self._component_specs[name] # warn if type changed - if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): - logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + if current_component_spec.type_hint is not None and not isinstance( + component, current_component_spec.type_hint + ): + logger.warning( + f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + ) # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) self._component_specs[name] = new_component_spec @@ -1938,46 +1935,54 @@ def update(self, **kwargs): created_components = {} for name, component_spec in passed_component_specs.items(): if component_spec.default_creation_method == "from_pretrained": - raise ValueError("ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + raise ValueError( + "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method" + ) created_components[name] = component_spec.create() current_component_spec = self._component_specs[name] # warn if type changed - if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): - logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + if current_component_spec.type_hint is not None and not isinstance( + created_components[name], current_component_spec.type_hint + ): + logger.warning( + f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + ) # update _component_specs based on the user passed component_spec self._component_specs[name] = component_spec self.register_components(**passed_components, **created_components) - config_to_register = {} for name, new_value in passed_config_values.items(): - # e.g. requires_aesthetics_score = False self._config_specs[name].default = new_value config_to_register[name] = new_value self.register_to_config(**config_to_register) - # YiYi TODO: support map for additional from_pretrained kwargs def load(self, component_names: Optional[List[str]] = None, **kwargs): """ Load selectedcomponents from specs. - + Args: component_names: List of component names to load **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, + `variant`, `revision`, etc. """ # if not specific name, load all the components with default_creation_method == "from_pretrained" if component_names is None: - component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] + component_names = [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] elif not isinstance(component_names, list): component_names = [component_names] - components_to_load = set([name for name in component_names if name in self._component_specs]) - unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + components_to_load = {name for name in component_names if name in self._component_specs} + unknown_component_names = {name for name in component_names if name not in self._component_specs} if len(unknown_component_names) > 0: logger.warning(f"Unknown components will be ignored: {unknown_component_names}") @@ -2063,7 +2068,6 @@ def to(self, *args, **kwargs) -> Self: from ..pipelines.pipeline_utils import _check_bnb_status from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version - dtype = kwargs.pop("dtype", None) device = kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) @@ -2219,8 +2223,9 @@ def module_is_offloaded(module): # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - + def save_pretrained( + self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs + ): component_names = list(self._component_specs.keys()) config_names = list(self._config_specs.keys()) self.register_to_config(_components_names=component_names, _configs_names=config_names) @@ -2230,11 +2235,16 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: config.pop("_configs_names", None) self._internal_dict = FrozenDict(config) - @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + spec_only: bool = True, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) expected_config = set(config_dict.pop("_configs_names")) @@ -2244,7 +2254,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P for name, value in config_dict.items(): if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value - # only pick up pretrained components from the repo + # only pick up pretrained components from the repo if component_spec_dict.get("repo", None) is not None: component_spec = cls._dict_to_component_spec(name, component_spec_dict) component_specs.append(component_spec) @@ -2254,12 +2264,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) - @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in - `modular_model_index.json`. + Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. This dict contains: - "type_hint": Tuple[str, str] @@ -2283,25 +2291,12 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: Dict[str, Any]: A mapping suitable for JSON serialization. Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec - >>> from diffusers.models.unet import UNet2DConditionModel - >>> spec = ComponentSpec( - ... name="unet", - ... type_hint=UNet2DConditionModel, - ... config=None, - ... repo="path/to/repo", - ... subfolder="subfolder", - ... variant=None, - ... revision=None, - ... default_creation_method="from_pretrained", - ... ) - >>> ModularLoader._component_spec_to_dict(spec) - { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": None, - "revision": None, + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers.models.unet + import UNet2DConditionModel >>> spec = ComponentSpec( ... name="unet", ... type_hint=UNet2DConditionModel, + ... config=None, ... repo="path/to/repo", ... subfolder="subfolder", ... variant=None, ... revision=None, + ... default_creation_method="from_pretrained", ... ) >>> ModularLoader._component_spec_to_dict(spec) { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": + "subfolder", "variant": None, "revision": None, } """ if component_spec.type_hint is not None: @@ -2370,11 +2365,9 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = if state is None: state = PipelineState() - # Make a copy of the input kwargs passed_kwargs = kwargs.copy() - # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs @@ -2412,7 +2405,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = if output is None: return state - elif isinstance(output, str): return state.get_intermediate(output) @@ -2421,7 +2413,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = else: raise ValueError(f"Output '{output}' is not a valid output type") - def load_components(self, component_names: Optional[List[str]] = None, **kwargs): self.loader.load(component_names=component_names, **kwargs) @@ -2430,16 +2421,28 @@ def update_components(self, **kwargs): @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) - pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs) + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + trust_remote_code: Optional[bool] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + pipeline = blocks.init_pipeline( + pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs + ) return pipeline - def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): + def save_pretrained( + self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs + ): self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - @property def doc(self): return self.blocks.doc diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 1b9874bb52bd..2d86c2540072 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -60,11 +60,11 @@ def __repr__(self): @dataclass class ComponentSpec: """Specification for a pipeline component. - + A component can be created in two ways: 1. From scratch using __init__ with a config dict 2. using `from_pretrained` - + Attributes: name: Name of the component type_hint: Type of the component (e.g. UNet2DConditionModel) @@ -76,6 +76,7 @@ class ComponentSpec: revision: Optional revision in repo default_creation_method: Preferred creation method - "from_config" or "from_pretrained" """ + name: Optional[str] = None type_hint: Optional[Type] = None description: Optional[str] = None @@ -87,7 +88,6 @@ class ComponentSpec: revision: Optional[str] = field(default=None, metadata={"loading": True}) default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - def __hash__(self): """Make ComponentSpec hashable, using load_id as the hash value.""" return hash((self.name, self.load_id, self.default_creation_method)) @@ -96,9 +96,11 @@ def __eq__(self, other): """Compare ComponentSpec objects based on name and load_id.""" if not isinstance(other, ComponentSpec): return False - return (self.name == other.name and - self.load_id == other.load_id and - self.default_creation_method == other.default_creation_method) + return ( + self.name == other.name + and self.load_id == other.load_id + and self.default_creation_method == other.default_creation_method + ) @classmethod def from_component(cls, name: str, component: Any) -> Any: @@ -125,22 +127,22 @@ def from_component(cls, name: str, component: Any) -> Any: load_spec = cls.decode_load_id(component._diffusers_load_id) - return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) + return cls( + name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec + ) @classmethod def loading_fields(cls) -> List[str]: """ - Return the names of all loading‐related fields - (i.e. those whose field.metadata["loading"] is True). + Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True). """ return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - @property def load_id(self) -> str: """ - Unique identifier for this spec's pretrained load, - composed of repo|subfolder|variant|revision (no empty segments). + Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty + segments). """ parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] @@ -150,21 +152,16 @@ def load_id(self) -> str: def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: """ Decode a load_id string back into a dictionary of loading fields and values. - + Args: load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" where None values are represented as "null" - + Returns: - Dict mapping loading field names to their values. e.g. - { - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": "variant", - "revision": "revision" - } - If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not created with `load` method). + Dict mapping loading field names to their values. e.g. { + "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision" + } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating + component not created with `load` method). """ # Get all loading fields in order @@ -185,7 +182,6 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: return result - # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) # the config info is lost in the process @@ -194,9 +190,7 @@ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **k """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): - raise ValueError( - "`type_hint` is required when using from_config creation method." - ) + raise ValueError("`type_hint` is required when using from_config creation method.") config = config or self.config or {} @@ -230,11 +224,14 @@ def load(self, **kwargs) -> Any: # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError("`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + raise ValueError( + "`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)" + ) if self.type_hint is None: try: from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") @@ -254,10 +251,10 @@ def load(self, **kwargs) -> Any: return component - @dataclass class ConfigSpec: """Specification for a pipeline configuration parameter.""" + name: str default: Any description: Optional[str] = None @@ -271,12 +268,13 @@ class ConfigSpec: @dataclass class InputParam: """Specification for an input parameter.""" + name: str = None type_hint: Any = None default: Any = None required: bool = False description: str = "" - kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + kwargs_type: str = None # YiYi Notes: remove this feature (maybe) def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -285,34 +283,33 @@ def __repr__(self): @dataclass class OutputParam: """Specification for an output parameter.""" + name: str type_hint: Any = None description: str = "" - kwargs_type: str = None # YiYi notes: remove this feature (maybe) + kwargs_type: str = None # YiYi notes: remove this feature (maybe) def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + return ( + f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + ) def format_inputs_short(inputs): """ Format input parameters into a string representation, with required params first followed by optional ones. - + Args: inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - + Returns: str: Formatted string of input parameters - + Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + >>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ... + InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps", + required=False, default=50) ... ] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5, + num_inference_steps=50' """ required_inputs = [param for param in inputs if param.required] optional_inputs = [param for param in inputs if not param.required] @@ -330,18 +327,18 @@ def format_inputs_short(inputs): def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): """ Formats intermediate inputs and outputs of a block into a string representation. - + Args: intermediates_inputs: List of intermediate input parameters required_intermediates_inputs: List of required intermediate input names intermediates_outputs: List of intermediate output parameters - + Returns: str: Formatted string like: Intermediates: - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only """ # Handle inputs input_parts = [] @@ -433,7 +430,7 @@ def wrap_text(text, indent, max_length): # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" # YiYi Notes: remove this line if we remove kwargs_type - name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name + name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name param_str = f"{param_indent}{name} (`{type_str}`" # Add optional tag and default value if parameter is an InputParam and optional @@ -446,11 +443,7 @@ def wrap_text(text, indent, max_length): # Add description on a new line with additional indentation and wrapping if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) + desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" @@ -514,7 +507,9 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty # Add each component with optional empty lines between them for i, component in enumerate(components): # Get type name, handling special cases - type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + type_name = ( + component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + ) component_desc = f"{component_indent}{component.name} (`{type_name}`)" if component.description: @@ -578,10 +573,18 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines return "\n".join(formatted_configs) -def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): +def make_doc_string( + inputs, + intermediates_inputs, + outputs, + description="", + class_name=None, + expected_components=None, + expected_configs=None, +): """ Generates a formatted documentation string describing the pipeline block's parameters and structure. - + Args: inputs: List of input parameters intermediates_inputs: List of intermediate input parameters @@ -590,9 +593,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class class_name (str, *optional*): Name of the class to include in the documentation expected_components (List[ComponentSpec], *optional*): List of expected components expected_configs (List[ConfigSpec], *optional*): List of expected configurations - + Returns: - str: A formatted string containing information about components, configs, call parameters, + str: A formatted string containing information about components, configs, call parameters, intermediate inputs/outputs, and final outputs. """ output = "" @@ -603,8 +606,8 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class # Add description if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) + desc_lines = description.strip().split("\n") + aligned_desc = "\n".join(" " + line for line in desc_lines) output += aligned_desc + "\n\n" # Add components section if provided diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 4855a9bcfcd1..7c3c33d3f648 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -18,69 +18,222 @@ # YiYi Notes: this is actually for SDXL, put it here for now SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "prompt": InputParam( + "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=Union[str, List[str]], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=Optional[dict], + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="Generator(s) for deterministic generation", + ), "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=Optional[float], + description="Fraction of denoising process to complete before termination", + ), # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=Optional[Tuple[int, int]], + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=Union[float, List[float]], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=Union[float, List[float]], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=Union[float, List[float]], + default=1.0, + description="Scale factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" + ), } SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" + ), "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + required=True, + description="Generated images", + ), } SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} @@ -99,7 +252,6 @@ "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", "display": "textarea", }, - "num_inference_steps": { "label": "Steps", "type": "int", @@ -146,7 +298,7 @@ }, } -DEFAULT_TYPE_MAPS ={ +DEFAULT_TYPE_MAPS = { "int": { "type": "int", "default": 0, @@ -182,8 +334,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): """ - Get the group name for a given parameter name, if not part of a group, return None - e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None + Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> + "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None """ if name is None: return None @@ -195,7 +347,6 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): class ModularNode(ConfigMixin): - config_name = "node_config.json" @classmethod @@ -205,7 +356,9 @@ def from_pretrained( trust_remote_code: Optional[bool] = None, **kwargs, ): - blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) return cls(blocks, **kwargs) def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): @@ -252,8 +405,8 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): self.name_mapping[inp.name] = param else: # if not, check if it's in the SDXL input schema, if so, - # 1. use the type hint to determine the type - # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} + # 1. use the type hint to determine the type + # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} if inp.type_hint is not None: type_str = str(inp.type_hint).lower() else: @@ -270,7 +423,6 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): # add the param dict to the inp_params dict input_params[inp.name] = param - component_params = {} for comp in self.blocks.expected_components: param = kwargs.pop(comp.name, None) @@ -352,7 +504,6 @@ def mellon_config(self): return self._convert_to_mellon_config() def _convert_to_mellon_config(self): - node = {} node["label"] = self.config.label node["category"] = self.config.category @@ -377,7 +528,6 @@ def _convert_to_mellon_config(self): else: logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") - for comp_name, comp_param in self.config.component_params.items(): if comp_name in self.name_mapping: mellon_name = self.name_mapping[comp_name] @@ -397,7 +547,6 @@ def _convert_to_mellon_config(self): else: logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") - for out_name, out_param in self.config.output_params.items(): if out_name in self.name_mapping: mellon_name = self.name_mapping[out_name] @@ -422,10 +571,10 @@ def _convert_to_mellon_config(self): def save_mellon_config(self, file_path): """ Save the Mellon configuration to a JSON file. - + Args: file_path (str or Path): Path where the JSON file will be saved - + Returns: Path: Path to the saved config file """ @@ -435,13 +584,10 @@ def save_mellon_config(self, file_path): os.makedirs(file_path.parent, exist_ok=True) # Create a combined dictionary with module definition and name mapping - config = { - "module": self.mellon_config, - "name_mapping": self.name_mapping - } + config = {"module": self.mellon_config, "name_mapping": self.name_mapping} # Save the config to file - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: json.dump(config, f, indent=2) logger.info(f"Mellon config and name mapping saved to {file_path}") @@ -452,10 +598,10 @@ def save_mellon_config(self, file_path): def load_mellon_config(cls, file_path): """ Load a Mellon configuration from a JSON file. - + Args: file_path (str or Path): Path to the JSON file containing Mellon config - + Returns: dict: The loaded combined configuration containing 'module' and 'name_mapping' """ @@ -464,16 +610,14 @@ def load_mellon_config(cls, file_path): if not file_path.exists(): raise FileNotFoundError(f"Config file not found: {file_path}") - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: config = json.load(f) logger.info(f"Mellon config loaded from {file_path}") - return config def process_inputs(self, **kwargs): - params_components = {} for comp_name, comp_param in self.config.component_params.items(): logger.debug(f"component: {comp_name}") @@ -486,7 +630,6 @@ def process_inputs(self, **kwargs): if comp: params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) - params_run = {} for inp_name, inp_param in self.config.input_params.items(): logger.debug(f"input: {inp_name}") @@ -509,14 +652,3 @@ def execute(self, **kwargs): self.blocks.loader.update(**params_components) output = self.blocks.run(**params_run, output=return_output_names) return output - - - - - - - - - - - diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 2fe15bbbee4a..94887aa2791f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -21,11 +21,24 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] - _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] - _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] - _import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"] + _import_structure["encoders"] = [ + "StableDiffusionXLAutoIPAdapterStep", + "StableDiffusionXLAutoVaeEncoderStep", + "StableDiffusionXLTextEncoderStep", + ] + _import_structure["modular_block_mappings"] = [ + "AUTO_BLOCKS", + "CONTROLNET_BLOCKS", + "CONTROLNET_UNION_BLOCKS", + "IMAGE2IMAGE_BLOCKS", + "INPAINT_BLOCKS", + "IP_ADAPTER_BLOCKS", + "SDXL_SUPPORTED_BLOCKS", + "TEXT2IMAGE_BLOCKS", + ] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 2032a57dcfcc..fd73d8d74943 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -38,16 +38,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - # TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that # things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by # always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the # configuration of guider is. - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -122,12 +118,11 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): - +def prepare_latents_img2img( + vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True +): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) + raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") image = image.to(device=device, dtype=dtype) @@ -162,8 +157,7 @@ def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_ima ) init_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: @@ -225,29 +219,91 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_pooled_prompt_embeds", + description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), ] @property def intermediates_outputs(self) -> List[str]: return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative pooled text embeddings used to guide the image generation", + ), + OutputParam( + "ip_adapter_embeds", + type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", + description="image embeddings for IP-Adapter", + ), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + kwargs_type="guider_input_fields", + description="negative image embeddings for IP-Adapter", + ), ] def check_inputs(self, components, block_state): - if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: raise ValueError( @@ -269,7 +325,9 @@ def check_inputs(self, components, block_state): if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + if block_state.negative_ip_adapter_embeds is not None and not isinstance( + block_state.negative_ip_adapter_embeds, list + ): raise ValueError("`negative_ip_adapter_embeds` must be a list") if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: @@ -292,27 +350,45 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt _, seq_len, _ = block_state.prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) if block_state.negative_prompt_embeds is not None: _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) if block_state.negative_pooled_prompt_embeds is not None: - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) if block_state.ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + block_state.ip_adapter_embeds[i] = torch.cat( + [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) if block_state.negative_ip_adapter_embeds is not None: for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): - block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + block_state.negative_ip_adapter_embeds[i] = torch.cat( + [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) self.add_block_state(state, block_state) @@ -331,8 +407,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." ) @property @@ -351,15 +427,28 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), ] @property def intermediates_outputs(self) -> List[str]: return [ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam( + "latent_timestep", + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image generation", + ), ] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components @@ -402,8 +491,6 @@ def get_timesteps(self, components, num_inference_steps, strength, device, denoi components.scheduler.set_begin_index(t_start) return timesteps, num_inference_steps - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -411,7 +498,11 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, ) def denoising_value_valid(dnv): @@ -422,19 +513,30 @@ def denoising_value_valid(dnv): block_state.num_inference_steps, block_state.strength, block_state.device, - denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + denoising_start=block_state.denoising_start + if denoising_value_valid(block_state.denoising_start) + else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat( + block_state.batch_size * block_state.num_images_per_prompt ) - block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): block_state.discrete_timestep_cutoff = int( round( components.scheduler.config.num_train_timesteps - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] self.add_block_state(state, block_state) @@ -442,7 +544,6 @@ def denoising_value_valid(dnv): class StableDiffusionXLSetTimestepsStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -453,9 +554,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) + return "Step that sets the scheduler's timesteps for inference" @property def inputs(self) -> List[InputParam]: @@ -468,9 +567,14 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: @@ -479,25 +583,35 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, ) - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): block_state.discrete_timestep_cutoff = int( round( components.scheduler.config.num_train_timesteps - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] self.add_block_state(state, block_state) return components, state class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -508,9 +622,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) + return "Step that prepares the latents for the inpainting process" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -526,7 +638,7 @@ def inputs(self) -> List[Tuple[str, Any]]: "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " "be maximum and the denoising process will run for the full number of iterations specified in " "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + "`denoising_start` being declared as an integer, the value of `strength` will be ignored.", ), ] @@ -538,51 +650,57 @@ def intermediates_inputs(self) -> List[str]: "batch_size", required=True, type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), InputParam( "latent_timestep", required=True, type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", ), InputParam( "image_latents", required=True, type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", ), InputParam( "mask", required=True, type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." + description="The mask for the inpainting generation. Can be generated in vae_encode step.", ), InputParam( "masked_image_latents", type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.", ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)", + ), + OutputParam( + "noise", + type_hint=torch.Tensor, + description="The noise added to the image latents, used for inpainting generation", + ), + ] # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from @staticmethod def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) @@ -733,7 +851,6 @@ def prepare_mask_latents( return mask, masked_image_latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -744,7 +861,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.is_strength_max = block_state.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components,"unet") and components.unet is not None: + if hasattr(components, "unet") and components.unet is not None: if components.unet.config.in_channels == 4: block_state.masked_image_latents = None @@ -801,9 +918,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) + return "Step that prepares the latents for the image-to-image generation process" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -817,14 +932,34 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: @@ -863,9 +998,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) + return "Prepare latents step that prepares the latents for the text-to-image generation process" @property def inputs(self) -> List[InputParam]: @@ -884,26 +1017,19 @@ def intermediates_inputs(self) -> List[InputParam]: "batch_size", required=True, type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), ] @property def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ) ] - @staticmethod def check_inputs(components, block_state): if ( @@ -918,7 +1044,9 @@ def check_inputs(components, block_state): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components @staticmethod - def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None + ): shape = ( batch_size, num_channels_latents, @@ -940,7 +1068,6 @@ def prepare_latents(components, batch_size, num_channels_latents, height, width, latents = latents * components.scheduler.init_noise_sigma return latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -973,18 +1100,17 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", False),] + return [ + ConfigSpec("requires_aesthetics_score", False), + ] @property def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) + return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1003,16 +1129,43 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @staticmethod @@ -1133,8 +1286,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt dtype=block_state.pooled_prompt_embeds.dtype, text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM block_state.timestep_cond = None @@ -1144,7 +1301,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) block_state.timestep_cond = self.get_guidance_scale_embedding( block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) @@ -1158,9 +1317,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): @property def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) + return "Step that prepares the additional conditioning for the text-to-image generation process" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -1181,27 +1338,39 @@ def intermediates_inputs(self) -> List[InputParam]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( "pooled_prompt_embeds", required=True, type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", ), InputParam( "batch_size", required=True, type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @staticmethod @@ -1289,8 +1458,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt else: block_state.negative_add_time_ids = block_state.add_time_ids - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM block_state.timestep_cond = None @@ -1300,7 +1473,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) block_state.timestep_cond = self.get_guidance_scale_embedding( block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) @@ -1310,14 +1485,18 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLControlNetInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), ] @property @@ -1342,24 +1521,24 @@ def intermediates_inputs(self) -> List[str]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( "batch_size", required=True, type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), InputParam( "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", ), InputParam( "crops_coords", type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", ), ] @@ -1367,15 +1546,19 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), - OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), - OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam( + "control_guidance_start", type_hint=List[float], description="The controlnet guidance start values" + ), + OutputParam( + "control_guidance_end", type_hint=List[float], description="The controlnet guidance end values" + ), + OutputParam( + "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values" + ), OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() @@ -1392,9 +1575,13 @@ def prepare_control_image( crops_coords=None, ): if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -1407,11 +1594,8 @@ def prepare_control_image( image = image.to(device=device, dtype=dtype) return image - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) # (1) prepare controlnet inputs @@ -1424,11 +1608,21 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # (1.1) # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 block_state.control_guidance_start, block_state.control_guidance_end = ( mult * [block_state.control_guidance_start], @@ -1437,8 +1631,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # (1.2) # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( + controlnet.nets + ) # (1.3) # global_pool_conditions @@ -1500,8 +1698,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - self.add_block_state(state, block_state) return components, state @@ -1514,7 +1710,12 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), ] @property @@ -1540,30 +1741,30 @@ def intermediates_inputs(self) -> List[InputParam]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.", ), InputParam( "batch_size", required=True, type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", ), - InputParam( + InputParam( "dtype", required=True, type_hint=torch.dtype, - description="The dtype of model tensor inputs. Can be generated in input step." + description="The dtype of model tensor inputs. Can be generated in input step.", ), InputParam( "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.", ), InputParam( "crops_coords", type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", ), ] @@ -1571,11 +1772,23 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam( + "control_type_idx", + type_hint=List[int], + description="The control mode indices", + kwargs_type="controlnet_kwargs", + ), + OutputParam( + "control_type", + type_hint=torch.Tensor, + description="The control type tensor that specifies which control type is active", + kwargs_type="controlnet_kwargs", + ), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam( + "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values" + ), OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -1596,9 +1809,13 @@ def prepare_control_image( crops_coords=None, ): if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -1613,7 +1830,6 @@ def prepare_control_image( @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) controlnet = unwrap_module(components.controlnet) @@ -1625,12 +1841,19 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.height = block_state.height * components.vae_scale_factor block_state.width = block_state.width * components.vae_scale_factor - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] # guess_mode block_state.global_pool_conditions = controlnet.config.global_pool_conditions @@ -1677,7 +1900,10 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt for i in range(len(block_state.timesteps)): block_state.controlnet_keep.append( 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + - float( + i / len(block_state.timesteps) < block_state.control_guidance_start + or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end + ) ) block_state.control_type_idx = block_state.control_mode block_state.controlnet_cond = block_state.control_image @@ -1689,77 +1915,107 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] block_names = ["controlnet_union", "controlnet"] block_trigger_inputs = ["control_mode", "control_image"] @property def description(self): - return "Controlnet Input step that prepare the controlnet input.\n" + \ - "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + return ( + "Controlnet Input step that prepare the controlnet input.\n" + + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + ) # Before denoise class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLControlNetAutoInput, + ] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + ) class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLControlNetAutoInput, + ] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): - return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + return ( + "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + ) class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLControlNetAutoInput, + ] block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): - return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + return ( + "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + ) class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_classes = [ + StableDiffusionXLInpaintBeforeDenoiseStep, + StableDiffusionXLImg2ImgBeforeDenoiseStep, + StableDiffusionXLBeforeDenoiseStep, + ] block_names = ["inpaint", "img2img", "text2img"] block_trigger_inputs = ["mask", "image_latents", None] @property def description(self): - return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ - " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ - " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." - + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 3a4e141775f5..96397a5f7648 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -35,10 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - class StableDiffusionXLDecodeStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -49,7 +46,8 @@ def expected_components(self) -> List[ComponentSpec]: "image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ] @property @@ -64,11 +62,24 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ) + ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components @staticmethod @@ -121,7 +132,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state.latents_std = ( torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + latents = ( + latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + ) else: latents = latents / components.vae.config.scaling_factor @@ -137,7 +150,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: if hasattr(components, "watermark") and components.watermark is not None: block_state.images = components.watermark.apply_watermark(block_state.images) - block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) self.add_block_state(state, block_state) @@ -149,8 +164,10 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): @property def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + return ( + "A post-processing step that overlays the mask on the image (inpainting task only).\n" + + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + ) @property def inputs(self) -> List[Tuple[str, Any]]: @@ -163,37 +180,59 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") + InputParam( + "images", + required=True, + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images from the decode step", + ), + InputParam( + "crops_coords", + required=True, + type_hint=Tuple[int, int], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images with the mask overlayed", + ) + ] @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: - block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] + block_state.images = [ + components.image_processor.apply_overlay( + block_state.mask_image, block_state.image, i, block_state.crops_coords + ) + for i in block_state.images + ] self.add_block_state(state, block_state) return components, state - class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] block_names = ["decode", "mask_overlay"] @property def description(self): - return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + return ( + "Inpaint decode step that decode the denoised latents into images outputs.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + ) class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): @@ -203,9 +242,9 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): @property def description(self): - return "Decode step that decode the denoised latents into images outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - - + return ( + "Decode step that decode the denoised latents into images outputs.\n" + + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 564665110006..9fa439d24420 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -36,11 +36,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - # YiYi experimenting composible denoise loop # loop step (1): prepare latent input for denoiser class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -53,7 +51,6 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return "step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" - @property def intermediates_inputs(self) -> List[str]: return [ @@ -61,20 +58,19 @@ def intermediates_inputs(self) -> List[str]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), ] @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) return components, block_state + # loop step (1): prepare latent input for denoiser (with inpainting) class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -88,7 +84,6 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object" - @property def intermediates_inputs(self) -> List[str]: return [ @@ -96,24 +91,22 @@ def intermediates_inputs(self) -> List[str]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( "mask", type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", ), InputParam( "masked_image_latents", type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", ), ] - @staticmethod def check_inputs(components, block_state): - num_channels_unet = components.num_channels_unet if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting @@ -127,25 +120,25 @@ def check_inputs(components, block_state): f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" " `components.unet` or your `mask_image` or `image` input." ) @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - self.check_inputs(components, block_state) block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - + block_state.scaled_latents = torch.cat( + [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + ) return components, block_state + # loop step (2): denoise the latents with guidance class StableDiffusionXLLoopDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -155,15 +148,14 @@ def expected_components(self) -> List[ComponentSpec]: "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ComponentSpec("unet", UNet2DConditionModel), ] @property def description(self) -> str: - return ( - "Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" - ) + return "Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -178,12 +170,12 @@ def intermediates_inputs(self) -> List[str]: "num_inference_steps", required=True, type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), InputParam( "timestep_cond", type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.", ), InputParam( kwargs_type="guider_input_fields", @@ -194,25 +186,23 @@ def intermediates_inputs(self) -> List[str]: "pooled_prompt_embeds/negative_pooled_prompt_embeds, " "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) + ), ), - ] - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - + def __call__( + self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int + ) -> PipelineState: # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ + guider_input_fields = { "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), "time_ids": ("add_time_ids", "negative_add_time_ids"), "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), } - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) # Prepare mini‐batches according to guidance method and `guider_input_fields` @@ -226,7 +216,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) cond_kwargs = guider_state_batch.as_dict() - cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} prompt_embeds = cond_kwargs.pop("prompt_embeds") # Predict the noise residual @@ -247,9 +237,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state + # loop step (2): denoise the latents with guidance (with controlnet) class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -259,7 +249,8 @@ def expected_components(self) -> List[ComponentSpec]: "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ] @@ -281,35 +272,35 @@ def intermediates_inputs(self) -> List[str]: "controlnet_cond", required=True, type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), InputParam( "conditioning_scale", type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), InputParam( "guess_mode", required=True, type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), InputParam( "controlnet_keep", required=True, type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), InputParam( "timestep_cond", type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step", ), InputParam( "num_inference_steps", required=True, type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), InputParam( kwargs_type="guider_input_fields", @@ -320,20 +311,19 @@ def intermediates_inputs(self) -> List[str]: "pooled_prompt_embeds/negative_pooled_prompt_embeds, " "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) + ), ), InputParam( kwargs_type="controlnet_kwargs", description=( "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ) + ), + ), ] @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepted_kwargs = set(inspect.signature(func).parameters.keys()) extra_kwargs = {} for key, value in kwargs.items(): @@ -342,25 +332,26 @@ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): return extra_kwargs - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + extra_controlnet_kwargs = self.prepare_extra_kwargs( + components.controlnet.forward, **block_state.controlnet_kwargs + ) # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ + guider_input_fields = { "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), "time_ids": ("add_time_ids", "negative_add_time_ids"), "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), } - # cond_scale for the timestep (controlnet input) if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + block_state.cond_scale = [ + c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i]) + ] else: controlnet_cond_scale = block_state.conditioning_scale if isinstance(controlnet_cond_scale, list): @@ -374,7 +365,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # guided denoiser step components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Prepare mini‐batches according to guidance method and `guider_input_fields` # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. # e.g. for CFG, we prepare two batches: one for uncond, one for cond # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds @@ -442,9 +433,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state + # loop step (3): scheduler step to update latents class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -473,10 +464,9 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - #YiYi TODO: move this out of here + # YiYi TODO: move this out of here @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepted_kwargs = set(inspect.signature(func).parameters.keys()) extra_kwargs = {} for key, value in kwargs.items(): @@ -485,17 +475,23 @@ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): return extra_kwargs - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): @@ -504,9 +500,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state + # loop step (3): scheduler step to update latents (with inpainting) class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" @property @@ -534,22 +530,22 @@ def intermediates_inputs(self) -> List[str]: "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", ), InputParam( "mask", type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", ), InputParam( "noise", type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.", ), InputParam( "image_latents", type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.", ), ] @@ -559,7 +555,6 @@ def intermediates_outputs(self) -> List[OutputParam]: @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepted_kwargs = set(inspect.signature(func).parameters.keys()) extra_kwargs = {} for key, value in kwargs.items(): @@ -579,16 +574,23 @@ def check_inputs(self, components, block_state): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - self.check_inputs(components, block_state) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): @@ -604,23 +606,20 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents return components, block_state # the loop wrapper that iterates over the timesteps class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): - model_name = "stable-diffusion-xl" @property def description(self) -> str: - return ( - "Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes" - ) + return "Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes" @property def loop_expected_components(self) -> List[ComponentSpec]: @@ -629,7 +628,8 @@ def loop_expected_components(self) -> List[ComponentSpec]: "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] @@ -641,17 +641,16 @@ def loop_intermediates_inputs(self) -> List[InputParam]: "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", ), InputParam( "num_inference_steps", required=True, type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), ] - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -662,12 +661,16 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt else: components.guider.enable() - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): progress_bar.update() self.add_block_state(state, block_state) @@ -677,7 +680,11 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # composing the denoising loops class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @property @@ -691,10 +698,16 @@ def description(self) -> str: " - `StableDiffusionXLLoopAfterDenoiser`\n" ) + # control_cond class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property def description(self) -> str: return ( @@ -706,10 +719,16 @@ def description(self) -> str: " - `StableDiffusionXLLoopAfterDenoiser`\n" ) + # mask class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLInpaintLoopAfterDenoiser] + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property def description(self) -> str: return ( @@ -720,10 +739,17 @@ def description(self) -> str: " - `StableDiffusionXLLoopDenoiser`\n" " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" ) + + # control_cond + mask class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLInpaintLoopAfterDenoiser] + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] + @property def description(self) -> str: return ( @@ -751,6 +777,7 @@ def description(self) -> str: " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided." ) + # all task with controlnet class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] @@ -766,6 +793,7 @@ def description(self) -> str: " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." ) + # all task with or without controlnet class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] @@ -779,4 +807,4 @@ def description(self) -> str: "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." - ) \ No newline at end of file + ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index a563ffbbbe86..0c088f73c2ee 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -60,7 +60,6 @@ def retrieve_latents( class StableDiffusionXLIPAdapterStep(PipelineBlock): model_name = "stable-diffusion-xl" - @property def description(self) -> str: return ( @@ -73,13 +72,19 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec( + "feature_extractor", + CLIPImageProcessor, + config=FrozenDict({"size": 224, "crop_size": 224}), + default_creation_method="from_config", + ), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ] @property @@ -89,16 +94,19 @@ def inputs(self) -> List[InputParam]: "ip_adapter_image", PipelineImageInput, required=True, - description="The image(s) to be used as ip adapter" + description="The image(s) to be used as ip adapter", ) ] - @property def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + OutputParam( + "negative_ip_adapter_embeds", + type_hint=torch.Tensor, + description="Negative IP adapter image embeddings", + ), ] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components @@ -129,7 +137,13 @@ def encode_image(components, image, device, num_images_per_prompt, output_hidden # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + self, + components, + ip_adapter_image, + ip_adapter_image_embeds, + device, + num_images_per_prompt, + prepare_unconditional_embeds, ): image_embeds = [] if prepare_unconditional_embeds: @@ -200,14 +214,11 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLTextEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" @property def description(self) -> str: - return( - "Text Encoder step that generate text_embeddings to guide the image generation" - ) + return "Text Encoder step that generate text_embeddings to guide the image generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -220,7 +231,8 @@ def expected_components(self) -> List[ComponentSpec]: "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ] @property @@ -238,22 +250,44 @@ def inputs(self) -> List[InputParam]: InputParam("clip_skip"), ] - @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative pooled text embeddings used to guide the image generation", + ), ] @staticmethod def check_inputs(block_state): - - if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + elif block_state.prompt_2 is not None and ( + not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + ): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") @staticmethod @@ -343,9 +377,15 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + tokenizers = ( + [components.tokenizer, components.tokenizer_2] + if components.tokenizer is not None + else [components.tokenizer_2] + ) text_encoders = ( - [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + [components.text_encoder, components.text_encoder_2] + if components.text_encoder is not None + else [components.text_encoder_2] ) if prompt_embeds is None: @@ -464,7 +504,9 @@ def encode_prompt( seq_len = negative_prompt_embeds.shape[1] if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=components.text_encoder_2.dtype, device=device + ) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) @@ -491,7 +533,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: # Get inputs and intermediates @@ -503,7 +544,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # Encode input prompt block_state.text_encoder_lora_scale = ( - block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None + block_state.cross_attention_kwargs.get("scale", None) + if block_state.cross_attention_kwargs is not None + else None ) ( block_state.prompt_embeds, @@ -532,15 +575,11 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt class StableDiffusionXLVaeEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" - @property def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) + return "Vae Encoder step that encode the input image into a latent representation" @property def expected_components(self) -> List[ComponentSpec]: @@ -550,7 +589,8 @@ def expected_components(self) -> List[ComponentSpec]: "image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ] @property @@ -566,16 +606,26 @@ def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + InputParam( + "preprocess_kwargs", + type_hint=Optional[dict], + description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", + ), + ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation", + ) + ] # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) @@ -609,8 +659,6 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge return image_latents - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -618,7 +666,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.device = components._execution_device block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = components.image_processor.preprocess( + block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs + ) block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) block_state.batch_size = block_state.image.shape[0] @@ -630,8 +680,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." ) - - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + block_state.image_latents = self._encode_vae_image( + components, image=block_state.image, generator=block_state.generator + ) self.add_block_state(state, block_state) @@ -649,20 +700,21 @@ def expected_components(self) -> List[ComponentSpec]: "image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), + default_creation_method="from_config", + ), ComponentSpec( "mask_processor", VaeImageProcessor, - config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), - default_creation_method="from_config"), + config=FrozenDict( + {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True} + ), + default_creation_method="from_config", + ), ] - @property def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) + return "Vae encoder step that prepares the image and mask for the inpainting process" @property def inputs(self) -> List[InputParam]: @@ -683,15 +735,26 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + return [ + OutputParam( + "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" + ), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)", + ), + OutputParam( + "crops_coords", + type_hint=Optional[Tuple[int, int]], + description="The crop coordinates to use for the preprocess/postprocess of the image and mask", + ), + ] # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) @@ -774,32 +837,45 @@ def prepare_mask_latents( return mask, masked_image_latents - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype block_state.device = components._execution_device if block_state.padding_mask_crop is not None: - block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.crops_coords = components.mask_processor.get_crop_region( + block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop + ) block_state.resize_mode = "fill" else: block_state.crops_coords = None block_state.resize_mode = "default" - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = components.image_processor.preprocess( + block_state.image, + height=block_state.height, + width=block_state.width, + crops_coords=block_state.crops_coords, + resize_mode=block_state.resize_mode, + ) block_state.image = block_state.image.to(dtype=torch.float32) - block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.mask = components.mask_processor.preprocess( + block_state.mask_image, + height=block_state.height, + width=block_state.width, + resize_mode=block_state.resize_mode, + crops_coords=block_state.crops_coords, + ) block_state.masked_image = block_state.image * (block_state.mask < 0.5) block_state.batch_size = block_state.image.shape[0] block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + block_state.image_latents = self._encode_vae_image( + components, image=block_state.image, generator=block_state.generator + ) # 7. Prepare mask latent variables block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( @@ -816,11 +892,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt self.add_block_state(state, block_state) - return components, state - # auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) # Encode class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): @@ -830,10 +904,12 @@ class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): @property def description(self): - return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + ) class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): @@ -844,4 +920,3 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMix @property def description(self): return "Run IP Adapter step if `ip_adapter_image` is provided." - diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py index 9440d72319f3..226266c3f6a7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py @@ -47,60 +47,74 @@ # YiYi notes: comment out for now, work on this later # block mapping -TEXT2IMAGE_BLOCKS = InsertableOrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseLoop), - ("decode", StableDiffusionXLDecodeStep) -]) +TEXT2IMAGE_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseLoop), + ("decode", StableDiffusionXLDecodeStep), + ] +) -IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseLoop), - ("decode", StableDiffusionXLDecodeStep) -]) +IMAGE2IMAGE_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseLoop), + ("decode", StableDiffusionXLDecodeStep), + ] +) -INPAINT_BLOCKS = InsertableOrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLInpaintDenoiseLoop), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) +INPAINT_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLInpaintDenoiseLoop), + ("decode", StableDiffusionXLInpaintDecodeStep), + ] +) -CONTROLNET_BLOCKS = InsertableOrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) +CONTROLNET_BLOCKS = InsertableOrderedDict( + [ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), + ] +) -CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) +CONTROLNET_UNION_BLOCKS = InsertableOrderedDict( + [ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), + ] +) -IP_ADAPTER_BLOCKS = InsertableOrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) +IP_ADAPTER_BLOCKS = InsertableOrderedDict( + [ + ("ip_adapter", StableDiffusionXLIPAdapterStep), + ] +) -AUTO_BLOCKS = InsertableOrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) +AUTO_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep), + ] +) SDXL_SUPPORTED_BLOCKS = { @@ -110,8 +124,5 @@ "controlnet": CONTROLNET_BLOCKS, "controlnet_union": CONTROLNET_UNION_BLOCKS, "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS + "auto": AUTO_BLOCKS, } - - - diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 0f567513c57d..59a723335965 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -30,7 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - # YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? # YiYi Notes: model specific components: ## (1) it should inherit from ModularLoader @@ -74,102 +73,285 @@ def num_channels_latents(self): return num_channels_latents - # YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "prompt": InputParam( + "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=Union[str, List[str]], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=Optional[dict], + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="Generator(s) for deterministic generation", + ), "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=Optional[float], + description="Fraction of denoising process to complete before termination", + ), # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=Optional[Tuple[int, int]], + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=Union[float, List[float]], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=Union[float, List[float]], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=Union[float, List[float]], + default=1.0, + description="Scale factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" + ), } SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" + ), "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + required=True, + description="Generated images", + ), } SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "prompt_embeds": OutputParam( + "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation" + ), + "negative_prompt_embeds": OutputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": OutputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": OutputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "image_latents": OutputParam( + "image_latents", type_hint=torch.Tensor, description="Latents representing reference image" + ), "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "masked_image_latents": OutputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "latent_timestep": OutputParam( + "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep" + ), "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "negative_add_time_ids": OutputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") + "ip_adapter_embeds": OutputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": OutputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="Generated images", + ), } SDXL_OUTPUTS_SCHEMA = { - "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") + "images": OutputParam( + "images", + type_hint=Union[ + Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput + ], + description="The final generated images", + ) } - diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 981f4d7e033a..eee395a860a7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -28,19 +28,24 @@ class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_classes = [ + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLAutoDecodeStep, + ] block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] @property def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - - - - + return ( + "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index effef5e42dc3..7e48cca09393 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -374,6 +374,7 @@ def maybe_raise_or_warn( # a simpler version of get_class_obj_and_candidates, it won't work with custom code def simple_get_class_obj(library_name, class_name): from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) if is_pipeline_module: @@ -385,6 +386,7 @@ def simple_get_class_obj(library_name, class_name): return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ccc714289df9..0375fbb0856a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1986,12 +1986,13 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } - optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] + optional_components = ( + pipeline._optional_components + if hasattr(pipeline, "_optional_components") and pipeline._optional_components + else [] + ) missing_modules = ( - set(expected_modules) - - set(optional_components) - - set(pipeline_kwargs.keys()) - - set(true_optional_modules) + set(expected_modules) - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) if len(missing_modules) > 0: From 9530245e17c8c293a93298ec284512aa6052794b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 12:10:35 +0200 Subject: [PATCH 093/170] correct code format --- docs/source/en/modular_diffusers/quicktour.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/quicktour.md index e8ef62403fd1..94e59b468e27 100644 --- a/docs/source/en/modular_diffusers/quicktour.md +++ b/docs/source/en/modular_diffusers/quicktour.md @@ -620,6 +620,7 @@ t2i_pipeline.doc ``` + ```py import torch from diffusers.modular_pipelines import SequentialPipelineBlocks @@ -862,7 +863,7 @@ image = decoder_node(latents=latents, output="images")[0] refined_image = decoder_node(latents=refined_latents, output="images")[0] ``` -# YiYi TODO: maybe more on controlnet/lora/ip-adapter +## YiYi TODO: maybe more on controlnet/lora/ip-adapter From c437ae72c6cc4a270d0f765c4b4cfb78f246b1ff Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 23:26:59 +0200 Subject: [PATCH 094/170] copies --- .../modular_pipelines/modular_pipeline.py | 2 +- .../stable_diffusion_xl/before_denoise.py | 28 +-- .../stable_diffusion_xl/decoders.py | 5 +- .../stable_diffusion_xl/encoders.py | 5 +- src/diffusers/utils/dummy_pt_objects.py | 229 ++++++++++++++++-- .../dummy_torch_and_transformers_objects.py | 45 ++-- 6 files changed, 261 insertions(+), 53 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 24d6e4caec13..24b2779941e2 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1814,7 +1814,7 @@ def device(self) -> torch.device: return torch.device("cpu") @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index fd73d8d74943..ed4dec87f743 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -451,7 +451,7 @@ def intermediates_outputs(self) -> List[str]: ), ] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: @@ -697,7 +697,7 @@ def intermediates_outputs(self) -> List[str]: ), ] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from @staticmethod def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): @@ -1042,10 +1042,9 @@ def check_inputs(components, block_state): f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->components def prepare_latents( - components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None + self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, @@ -1167,9 +1166,9 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids_img2img( + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + self, components, original_size, crops_coords_top_left, @@ -1221,9 +1220,8 @@ def _get_add_time_ids_img2img( return add_time_ids, add_neg_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -1273,7 +1271,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt if block_state.negative_target_size is None: block_state.negative_target_size = block_state.target_size - block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids( components, block_state.original_size, block_state.crops_coords_top_left, @@ -1372,10 +1370,9 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components - @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components def _get_add_time_ids( - components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -1393,9 +1390,8 @@ def _get_add_time_ids( return add_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 96397a5f7648..921dcfaa2d56 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -81,9 +81,8 @@ def intermediates_outputs(self) -> List[str]: ) ] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - @staticmethod - def upcast_vae(components): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components + def upcast_vae(self, components): dtype = components.vae.dtype components.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 0c088f73c2ee..eff01ac60d24 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -109,9 +109,8 @@ def intermediates_outputs(self) -> List[OutputParam]: ), ] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - @staticmethod - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components + def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(components.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 43f0b5a3d00f..496039a436e5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,111 @@ from ..utils import DummyObject, requires_backends +class AdaptiveProjectedGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeZeroStarGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SkipLayerGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SmoothedEnergyGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class TangentialClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] @@ -32,6 +137,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LayerSkipConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -47,10 +167,29 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SmoothedEnergyGuidanceConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) +def apply_layer_skip(*args, **kwargs): + requires_backends(apply_layer_skip, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) @@ -1180,6 +1319,81 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ComponentsManager(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ComponentSpec(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModularLoader(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModularPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModularPipelineBlocks(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) @@ -1463,21 +1677,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularLoader(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d76bb8ab3e08..782df0204f3f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class StableDiffusionXLAutoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLModularLoader(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2582,21 +2612,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularLoader(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From f3453f05ff29cca86658d3d793c960c471a3b2e3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 00:47:33 +0200 Subject: [PATCH 095/170] copy --- .../stable_diffusion_xl/before_denoise.py | 22 +++++++++---------- .../stable_diffusion_xl/decoders.py | 3 ++- .../stable_diffusion_xl/encoders.py | 3 ++- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index ed4dec87f743..0484d5dd7ac5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -451,8 +451,9 @@ def intermediates_outputs(self) -> List[str]: ), ] + @staticmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components - def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -1042,15 +1043,14 @@ def check_inputs(components, block_state): f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->components - def prepare_latents( - self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None - ): + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp + def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, + int(height) // comp.vae_scale_factor, + int(width) // comp.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -1064,7 +1064,7 @@ def prepare_latents( latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * components.scheduler.init_noise_sigma + latents = latents * comp.scheduler.init_noise_sigma return latents @torch.no_grad() @@ -1166,9 +1166,9 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] + @staticmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components def _get_add_time_ids( - self, components, original_size, crops_coords_top_left, @@ -1369,10 +1369,10 @@ def intermediates_outputs(self) -> List[OutputParam]: ), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] - + @staticmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components def _get_add_time_ids( - self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 921dcfaa2d56..1f1a8c477de5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -81,8 +81,9 @@ def intermediates_outputs(self) -> List[str]: ) ] + @staticmethod # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components - def upcast_vae(self, components): + def upcast_vae(components): dtype = components.vae.dtype components.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index eff01ac60d24..d4ec17ada5ea 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -109,8 +109,9 @@ def intermediates_outputs(self) -> List[OutputParam]: ), ] + @staticmethod # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components - def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(components.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): From a82e211f89634d08203d5f5ecce72f69e65d63df Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 00:48:23 +0200 Subject: [PATCH 096/170] style --- .../modular_pipelines/stable_diffusion_xl/before_denoise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 0484d5dd7ac5..be04d2cfd608 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -1369,6 +1369,7 @@ def intermediates_outputs(self) -> List[OutputParam]: ), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] + @staticmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components def _get_add_time_ids( From a33206d22b8b9208025b0c5137fe5afc4f5e9829 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 01:31:51 +0200 Subject: [PATCH 097/170] fix --- src/diffusers/modular_pipelines/modular_pipeline.py | 5 ++--- src/diffusers/modular_pipelines/modular_pipeline_utils.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 24b2779941e2..750126e43ea7 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1351,8 +1351,8 @@ def get_inputs(self): input_param.required = False return combined_inputs - # Copied from SequentialPipelineBlocks @property + # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs def inputs(self): return self.get_inputs() @@ -1366,7 +1366,7 @@ def intermediates_inputs(self): intermediates.append(loop_intermediate_input) return intermediates - # Copied from SequentialPipelineBlocks + # modified from SequentialPipelineBlocks def get_intermediates_inputs(self): inputs = [] outputs = set() @@ -1429,7 +1429,6 @@ def intermediates_outputs(self) -> List[str]: return combined_outputs # YiYi TODO: this need to be thought about more - # copied from SequentialPipelineBlocks @property def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 2d86c2540072..015bec89837d 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -80,7 +80,7 @@ class ComponentSpec: name: Optional[str] = None type_hint: Optional[Type] = None description: Optional[str] = None - config: Optional[FrozenDict[str, Any]] = None + config: Optional[FrozenDict] = None # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) subfolder: Optional[str] = field(default=None, metadata={"loading": True}) From 75e62385f5f05d7656a058bf23d9cde3b6c487b6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 01:35:00 +0200 Subject: [PATCH 098/170] revert changes in pipelines.stable_diffusion_xl folder, can seperate PR later --- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 11 ++++++----- .../pipeline_controlnet_union_sd_xl_img2img.py | 11 ++++++----- .../pipelines/kolors/pipeline_kolors_img2img.py | 11 ++++++----- .../pag/pipeline_pag_controlnet_sd_xl_img2img.py | 11 ++++++----- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 11 ++++++----- .../pipeline_stable_diffusion_xl_img2img.py | 11 ++++++----- 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 6c45c347950d..07d07cc60e59 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -912,6 +912,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -925,11 +931,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index d3b0e372609e..4b283710d0f4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -867,6 +867,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -880,11 +886,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 0ea09c43b053..a9af5512cc80 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -609,6 +609,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -622,11 +628,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index a2b2f59b8ce3..8059fa5c7942 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -917,6 +917,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -930,11 +936,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 07d3d92c3138..e582558f5246 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -707,6 +707,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -720,11 +726,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index fde2392dca4d..31a701d08211 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -695,6 +695,12 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") @@ -708,11 +714,6 @@ def prepare_latents( init_latents = image else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() From 129d658da7b70bfce09b0163efbfa5dffaa85976 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 01:36:43 +0200 Subject: [PATCH 099/170] oops, fix --- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 10 +++++----- .../pipeline_controlnet_union_sd_xl_img2img.py | 10 +++++----- .../pipelines/kolors/pipeline_kolors_img2img.py | 10 +++++----- .../pag/pipeline_pag_controlnet_sd_xl_img2img.py | 10 +++++----- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 10 +++++----- .../pipeline_stable_diffusion_xl_img2img.py | 10 +++++----- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 07d07cc60e59..526e1ffcb2cc 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -912,11 +912,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 4b283710d0f4..82ef4b6391eb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -867,11 +867,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index a9af5512cc80..e3cf4f227624 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -609,11 +609,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index 8059fa5c7942..913a647fae3e 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -917,11 +917,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index e582558f5246..8c355a5fb129 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -707,11 +707,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 31a701d08211..e63c7a55ce7b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -695,11 +695,11 @@ def prepare_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: From da4242d467482f2a2dc7247e871d3ba7b8525927 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 03:36:34 +0200 Subject: [PATCH 100/170] use diffusers ModelHook, raise a import error for accelerate inside enable_auto_cpu_offload --- src/diffusers/modular_pipelines/components_manager.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 8f5c04d8a94d..59f2509be427 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -26,9 +26,11 @@ logging, ) +from ..hooks import ModelHook + if is_accelerate_available(): - from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module + from accelerate.hooks import add_hook_to_module, remove_hook_from_module from accelerate.state import PartialState from accelerate.utils import send_to_device from accelerate.utils.memory import clear_device_cache @@ -67,6 +69,7 @@ class CustomOffloadHook(ModelHook): The device on which the model should be executed. Will default to the MPS device if it's available, then GPU 0 if there is a GPU, and finally to the CPU. """ + no_grad = False def __init__( self, @@ -538,6 +541,10 @@ def matches_pattern(component_id, pattern, exact_match=False): raise ValueError(f"Invalid type for names: {type(names)}") def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): + + if not is_accelerate_available(): + raise ImportError("Make sure to install accelerate to use auto_cpu_offload") + for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) From ab6d63407a5c46cb96205f395929386d21dee251 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 03:37:58 +0200 Subject: [PATCH 101/170] style --- src/diffusers/modular_pipelines/components_manager.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 59f2509be427..60d78578940f 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -21,13 +21,12 @@ import torch +from ..hooks import ModelHook from ..utils import ( is_accelerate_available, logging, ) -from ..hooks import ModelHook - if is_accelerate_available(): from accelerate.hooks import add_hook_to_module, remove_hook_from_module @@ -69,6 +68,7 @@ class CustomOffloadHook(ModelHook): The device on which the model should be executed. Will default to the MPS device if it's available, then GPU 0 if there is a GPU, and finally to the CPU. """ + no_grad = False def __init__( @@ -541,10 +541,9 @@ def matches_pattern(component_id, pattern, exact_match=False): raise ValueError(f"Invalid type for names: {type(names)}") def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): - if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") - + for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): remove_hook_from_module(component, recurse=True) From 7492e331b44c4841ff3fc8431240ee7d8e63b137 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 03:43:10 +0200 Subject: [PATCH 102/170] fix --- src/diffusers/modular_pipelines/modular_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 750126e43ea7..186a7cd0e80a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -70,8 +70,8 @@ class PipelineState: inputs: Dict[str, Any] = field(default_factory=dict) intermediates: Dict[str, Any] = field(default_factory=dict) - input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + input_kwargs: Dict[str, List[str]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict) def add_input(self, key: str, value: Any, kwargs_type: str = None): """ From b92cda25e21ed8d3769dddf5a623029cfde19d3b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:39:13 +0200 Subject: [PATCH 103/170] move quicktour to first page --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4731cc4d57df..f3d587d067c8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -92,10 +92,10 @@ title: API Reference title: Hybrid Inference - sections: - - local: modular_diffusers/developer_guide - title: Developer Guide - local: modular_diffusers/quicktour title: Quicktour + - local: modular_diffusers/developer_guide + title: Developer Guide title: Modular Diffusers - sections: - local: using-diffusers/consisid From 61772f0994e8d5a9e17287799dcd7f759c3c1696 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:39:53 +0200 Subject: [PATCH 104/170] updatee a comment --- src/diffusers/modular_pipelines/modular_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 186a7cd0e80a..9981f87efcaf 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -635,9 +635,9 @@ def __init__(self): ) default_blocks = [t for t in self.block_trigger_inputs if t is None] # can only have 1 or 0 default block, and has to put in the last - # the order of blocksmatters here because the first block with matching trigger will be dispatched + # the order of blocks matters here because the first block with matching trigger will be dispatched # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img + # as long as mask is provided, it is inpaint; if only image is provided, it is img2img if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): raise ValueError( f"In {self.__class__.__name__}, exactly one None must be specified as the last element " From 9abac85f7714d7a89369739e511860d5a7f9c770 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:40:38 +0200 Subject: [PATCH 105/170] remove mapping file, move to preeset.py --- .../modular_block_mappings.py | 128 ------------------ 1 file changed, 128 deletions(-) delete mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py deleted file mode 100644 index 226266c3f6a7..000000000000 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..modular_pipeline_utils import InsertableOrderedDict -from .before_denoise import ( - StableDiffusionXLAutoBeforeDenoiseStep, - StableDiffusionXLControlNetInputStep, - StableDiffusionXLControlNetUnionInputStep, - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLImg2ImgPrepareLatentsStep, - StableDiffusionXLImg2ImgSetTimestepsStep, - StableDiffusionXLInpaintPrepareLatentsStep, - StableDiffusionXLInputStep, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, -) -from .decoders import StableDiffusionXLAutoDecodeStep, StableDiffusionXLDecodeStep, StableDiffusionXLInpaintDecodeStep - -# Import all the necessary block classes -from .denoise import ( - StableDiffusionXLAutoDenoiseStep, - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDenoiseLoop, - StableDiffusionXLInpaintDenoiseLoop, -) -from .encoders import ( - StableDiffusionXLAutoIPAdapterStep, - StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLInpaintVaeEncoderStep, - StableDiffusionXLIPAdapterStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLVaeEncoderStep, -) - - -# YiYi notes: comment out for now, work on this later -# block mapping -TEXT2IMAGE_BLOCKS = InsertableOrderedDict( - [ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseLoop), - ("decode", StableDiffusionXLDecodeStep), - ] -) - -IMAGE2IMAGE_BLOCKS = InsertableOrderedDict( - [ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseLoop), - ("decode", StableDiffusionXLDecodeStep), - ] -) - -INPAINT_BLOCKS = InsertableOrderedDict( - [ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLInpaintDenoiseLoop), - ("decode", StableDiffusionXLInpaintDecodeStep), - ] -) - -CONTROLNET_BLOCKS = InsertableOrderedDict( - [ - ("controlnet_input", StableDiffusionXLControlNetInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), - ] -) - -CONTROLNET_UNION_BLOCKS = InsertableOrderedDict( - [ - ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), - ] -) - -IP_ADAPTER_BLOCKS = InsertableOrderedDict( - [ - ("ip_adapter", StableDiffusionXLIPAdapterStep), - ] -) - -AUTO_BLOCKS = InsertableOrderedDict( - [ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep), - ] -) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS, -} From 84f4b27dfa7d1aa1415a1c820b9871feeb5b61bb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:41:16 +0200 Subject: [PATCH 106/170] modular_pipeline_presets.py -> modular_blocks_presets.py --- .../modular_blocks_presets.py | 380 ++++++++++++++++++ .../modular_pipeline_presets.py | 51 --- 2 files changed, 380 insertions(+), 51 deletions(-) create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py delete mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py new file mode 100644 index 000000000000..fb1b99a7086b --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -0,0 +1,380 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableOrderedDict +from .before_denoise import ( + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintOverlayMaskStep, +) +from .denoise import ( + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLInpaintControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, +) +from .encoders import ( + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto blocks & sequential blocks & mappings + + +# vae encoder (run before before_denoise) +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if neither `mask_image` nor `image` is provided, step will be skipped." + ) + + +# optional ip-adapter (run before before_denoise) +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." + + +# before_denoise: text2img +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: img2img +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: inpainting +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLInpaintBeforeDenoiseStep, + StableDiffusionXLImg2ImgBeforeDenoiseStep, + StableDiffusionXLBeforeDenoiseStep, + ] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + ) + + +# optional controlnet input step (after before_denoise, before denoise) +# works for both controlnet and controlnet_union +class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + @property + def description(self): + return ( + "Controlnet Input step that prepare the controlnet input.\n" + + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + + " (it should be called right before the denoise step)" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + + " - if neither `control_mode` nor `control_image` is provided, step will be skipped." + ) + + +# denoise: controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", "controlnet_cond"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks." + "This block should not be used without a controlnet_cond input" + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided." + " - If neither mask nor controlnet_cond are provided, step will be skipped." + ) + + +# denoise: all task with or without controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, + StableDiffusionXLDenoiseStep, + ] + block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", "mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." + " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)." + ) + + +# decode: inpaint +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return ( + "Inpaint decode step that decode the denoised latents into images outputs.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + ) + + +# decode: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return ( + "Decode step that decode the denoised latents into images outputs.\n" + + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + ) + + +# ip-adapter, controlnet, text2img, img2img, inpainting +class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLAutoDecodeStep, + ] + block_names = [ + "text_encoder", + "ip_adapter", + "image_encoder", + "before_denoise", + "controlnet_input", + "denoise", + "decoder", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) + +# controlnet (input + denoise step) +class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoControlNetDenoiseStep, + ] + block_names = ["controlnet_input", "controlnet_denoise"] + + @property + def description(self): + return ( + "Controlnet auto step that prepare the controlnet input and denoise the latents. " + + "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks." + + " (it should be replace at 'denoise' step)" + ) + + + +TEXT2IMAGE_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep), + ] +) + +IMAGE2IMAGE_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep), + ] +) + +INPAINT_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLInpaintDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep), + ] +) + +CONTROLNET_BLOCKS = InsertableOrderedDict( + [ + ("denoise", StableDiffusionXLAutoControlnetStep), + ] +) + + +IP_ADAPTER_BLOCKS = InsertableOrderedDict( + [ + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ] +) + +AUTO_BLOCKS = InsertableOrderedDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("controlnet_input", StableDiffusionXLAutoControlNetInputStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep), + ] +) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py deleted file mode 100644 index eee395a860a7..000000000000 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...utils import logging -from ..modular_pipeline import SequentialPipelineBlocks -from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep -from .decoders import StableDiffusionXLAutoDecodeStep -from .denoise import StableDiffusionXLAutoDenoiseStep -from .encoders import ( - StableDiffusionXLAutoIPAdapterStep, - StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLTextEncoderStep, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [ - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoIPAdapterStep, - StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLAutoBeforeDenoiseStep, - StableDiffusionXLAutoDenoiseStep, - StableDiffusionXLAutoDecodeStep, - ] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" - + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" - + "- to run the controlnet workflow, you need to provide `control_image`\n" - + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" - + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" - + "- for text-to-image generation, all you need to provide is `prompt`" - ) From 449f299c633eaa551c98d29c9cdb323e99e81a92 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:43:14 +0200 Subject: [PATCH 107/170] move all the sequential pipelines & auto pipelines to the blocks_presets.py --- src/diffusers/__init__.py | 4 +- src/diffusers/modular_pipelines/__init__.py | 4 +- .../stable_diffusion_xl/__init__.py | 26 ++--- .../stable_diffusion_xl/before_denoise.py | 109 ------------------ .../stable_diffusion_xl/decoders.py | 31 ----- .../stable_diffusion_xl/denoise.py | 101 +++++++--------- .../stable_diffusion_xl/encoders.py | 31 +---- .../dummy_torch_and_transformers_objects.py | 2 +- 8 files changed, 59 insertions(+), 249 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7bb8469c36d1..010bb96045ab 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -358,7 +358,7 @@ else: _import_structure["modular_pipelines"].extend( [ - "StableDiffusionXLAutoPipeline", + "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularLoader", ] ) @@ -979,7 +979,7 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( - StableDiffusionXLAutoPipeline, + StableDiffusionXLAutoBlocks, StableDiffusionXLModularLoader, ) from .pipelines import ( diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index f6e398268ca0..16f0becd8850 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -39,7 +39,7 @@ "InputParam", "OutputParam", ] - _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularLoader"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -68,7 +68,7 @@ OutputParam, ) from .stable_diffusion_xl import ( - StableDiffusionXLAutoPipeline, + StableDiffusionXLAutoBlocks, StableDiffusionXLModularLoader, ) else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 94887aa2791f..7102c76a0a81 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -21,24 +21,21 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] - _import_structure["encoders"] = [ - "StableDiffusionXLAutoIPAdapterStep", - "StableDiffusionXLAutoVaeEncoderStep", - "StableDiffusionXLTextEncoderStep", - ] - _import_structure["modular_block_mappings"] = [ + _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] + _import_structure["modular_blocks_presets"] = [ "AUTO_BLOCKS", "CONTROLNET_BLOCKS", - "CONTROLNET_UNION_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "IP_ADAPTER_BLOCKS", "SDXL_SUPPORTED_BLOCKS", "TEXT2IMAGE_BLOCKS", + "StableDiffusionXLAutoDecodeStep", + "StableDiffusionXLAutoIPAdapterStep", + "StableDiffusionXLAutoVaeEncoderStep", + "StableDiffusionXLAutoBlocks", ] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] - _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -47,24 +44,23 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .decoders import StableDiffusionXLAutoDecodeStep from .encoders import ( - StableDiffusionXLAutoIPAdapterStep, - StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLTextEncoderStep, ) - from .modular_block_mappings import ( + from .modular_blocks_presets import ( AUTO_BLOCKS, CONTROLNET_BLOCKS, - CONTROLNET_UNION_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, IP_ADAPTER_BLOCKS, SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, + StableDiffusionXLAutoDecodeStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLAutoBlocks, ) from .modular_loader import StableDiffusionXLModularLoader - from .modular_pipeline_presets import StableDiffusionXLAutoPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index be04d2cfd608..33ebf39e5478 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -26,10 +26,8 @@ from ...utils import logging from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ( - AutoPipelineBlocks, PipelineBlock, PipelineState, - SequentialPipelineBlocks, ) from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_loader import StableDiffusionXLModularLoader @@ -1909,110 +1907,3 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt self.add_block_state(state, block_state) return components, state - - -class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] - block_names = ["controlnet_union", "controlnet"] - block_trigger_inputs = ["control_mode", "control_image"] - - @property - def description(self): - return ( - "Controlnet Input step that prepare the controlnet input.\n" - + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" - + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" - + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." - ) - - -# Before denoise -class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - StableDiffusionXLInputStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLControlNetAutoInput, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step.\n" - + "This is a sequential pipeline blocks:\n" - + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" - + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" - + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" - + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" - + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - ) - - -class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - StableDiffusionXLInputStep, - StableDiffusionXLImg2ImgSetTimestepsStep, - StableDiffusionXLImg2ImgPrepareLatentsStep, - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLControlNetAutoInput, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" - + "This is a sequential pipeline blocks:\n" - + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" - + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" - + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" - + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" - + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - ) - - -class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - StableDiffusionXLInputStep, - StableDiffusionXLImg2ImgSetTimestepsStep, - StableDiffusionXLInpaintPrepareLatentsStep, - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLControlNetAutoInput, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" - + "This is a sequential pipeline blocks:\n" - + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" - + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" - + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" - + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" - + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - ) - - -class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [ - StableDiffusionXLInpaintBeforeDenoiseStep, - StableDiffusionXLImg2ImgBeforeDenoiseStep, - StableDiffusionXLBeforeDenoiseStep, - ] - block_names = ["inpaint", "img2img", "text2img"] - block_trigger_inputs = ["mask", "image_latents", None] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step.\n" - + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" - + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" - + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" - + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" - + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" - + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." - ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 1f1a8c477de5..12e3e92d292a 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -24,10 +24,8 @@ from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...utils import logging from ..modular_pipeline import ( - AutoPipelineBlocks, PipelineBlock, PipelineState, - SequentialPipelineBlocks, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -219,32 +217,3 @@ def __call__(self, components, state: PipelineState) -> PipelineState: self.add_block_state(state, block_state) return components, state - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] - block_names = ["decode", "mask_overlay"] - - @property - def description(self): - return ( - "Inpaint decode step that decode the denoised latents into images outputs.\n" - + "This is a sequential pipeline blocks:\n" - + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" - + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" - ) - - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return ( - "Decode step that decode the denoised latents into images outputs.\n" - + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" - + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" - + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 9fa439d24420..63c67f436799 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -23,7 +23,6 @@ from ...schedulers import EulerDiscreteScheduler from ...utils import logging from ..modular_pipeline import ( - AutoPipelineBlocks, BlockState, LoopSequentialPipelineBlocks, PipelineBlock, @@ -49,7 +48,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + return ( + "step within the denoising loop that prepare the latent input for the denoiser. " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) @property def intermediates_inputs(self) -> List[str]: @@ -82,7 +85,10 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object" + return ( + "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object" + ) @property def intermediates_inputs(self) -> List[str]: @@ -155,7 +161,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) @property def inputs(self) -> List[Tuple[str, Any]]: @@ -257,7 +267,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that denoise the latents with guidance (with controlnet). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + return ( + "step within the denoising loop that denoise the latents with guidance (with controlnet). " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) @property def inputs(self) -> List[Tuple[str, Any]]: @@ -446,7 +460,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that update the latents. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) @property def inputs(self) -> List[Tuple[str, Any]]: @@ -514,7 +532,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def description(self) -> str: - return "step within the denoising loop that update the latents (for inpainting workflow only). This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + return ( + "step within the denoising loop that update the latents (for inpainting workflow only). " + "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) @property def inputs(self) -> List[Tuple[str, Any]]: @@ -619,7 +641,10 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): @property def description(self) -> str: - return "Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes" + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `blocks` attributes" + ) @property def loop_expected_components(self) -> List[ComponentSpec]: @@ -679,7 +704,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # composing the denoising loops -class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): block_classes = [ StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, @@ -696,11 +721,12 @@ def description(self) -> str: " - `StableDiffusionXLLoopBeforeDenoiser`\n" " - `StableDiffusionXLLoopDenoiser`\n" " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports both text2img and img2img tasks." ) # control_cond -class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): block_classes = [ StableDiffusionXLLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, @@ -717,11 +743,12 @@ def description(self) -> str: " - `StableDiffusionXLLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetLoopDenoiser`\n" " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports using controlnet for both text2img and img2img tasks." ) # mask -class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): block_classes = [ StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, @@ -738,11 +765,12 @@ def description(self) -> str: " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" " - `StableDiffusionXLLoopDenoiser`\n" " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block onlysupports inpainting tasks." ) # control_cond + mask -class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): block_classes = [ StableDiffusionXLInpaintLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, @@ -759,52 +787,5 @@ def description(self) -> str: " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetLoopDenoiser`\n" " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" - ) - - -# all task without controlnet -class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." - " - `StableDiffusionXLDenoiseStep` (denoise) is used when no mask is provided." - " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided." - ) - - -# all task with controlnet -class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] - block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] - block_trigger_inputs = ["mask", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents with controlnet. " - "This is a auto pipeline block that works for text2img, img2img and inpainting tasks." - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when no mask is provided." - " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." - ) - - -# all task with or without controlnet -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." - " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (work for text2img, img2img and inpainting tasks)." + "This block only supports using controlnet for inpainting tasks." ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index d4ec17ada5ea..3561cbf70a82 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -26,7 +26,7 @@ from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...utils import ( @@ -35,7 +35,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ..modular_pipeline import AutoPipelineBlocks, PipelineBlock, PipelineState +from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_loader import StableDiffusionXLModularLoader @@ -893,30 +893,3 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt self.add_block_state(state, block_state) return components, state - - -# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations.\n" - + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" - + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" - + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." - ) - - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 782df0204f3f..b51cbb34add5 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,7 +2,7 @@ from ..utils import DummyObject, requires_backends -class StableDiffusionXLAutoPipeline(metaclass=DummyObject): +class StableDiffusionXLAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 7608d2eb9ee847c45c348989016da8e0fa4b972b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 26 Jun 2025 12:44:02 +0200 Subject: [PATCH 108/170] style --- .../modular_pipelines/stable_diffusion_xl/__init__.py | 4 ++-- .../stable_diffusion_xl/modular_blocks_presets.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 7102c76a0a81..fd767609f6f1 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -30,10 +30,10 @@ "IP_ADAPTER_BLOCKS", "SDXL_SUPPORTED_BLOCKS", "TEXT2IMAGE_BLOCKS", + "StableDiffusionXLAutoBlocks", "StableDiffusionXLAutoDecodeStep", "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", - "StableDiffusionXLAutoBlocks", ] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] @@ -55,10 +55,10 @@ IP_ADAPTER_BLOCKS, SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, + StableDiffusionXLAutoBlocks, StableDiffusionXLAutoDecodeStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLAutoBlocks, ) from .modular_loader import StableDiffusionXLModularLoader else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py index fb1b99a7086b..de34c50d43be 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -288,6 +288,7 @@ def description(self): + "- for text-to-image generation, all you need to provide is `prompt`" ) + # controlnet (input + denoise step) class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks): block_classes = [ @@ -305,7 +306,6 @@ def description(self): ) - TEXT2IMAGE_BLOCKS = InsertableOrderedDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), From f63d62e091259fe0085dc0abcb394ee83e8ded60 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 27 Jun 2025 12:48:30 +0200 Subject: [PATCH 109/170] intermediates_inputs -> intermediate_inputs; component_manager -> components_manager, and more --- .../modular_pipelines/components_manager.py | 61 +---- .../modular_pipelines/modular_pipeline.py | 225 +++++++++--------- .../modular_pipeline_utils.py | 38 +-- src/diffusers/modular_pipelines/node_utils.py | 12 +- .../stable_diffusion_xl/__init__.py | 6 +- .../stable_diffusion_xl/before_denoise.py | 38 +-- .../stable_diffusion_xl/decoders.py | 14 +- .../stable_diffusion_xl/denoise.py | 18 +- .../stable_diffusion_xl/encoders.py | 24 +- .../modular_blocks_presets.py | 8 +- .../stable_diffusion_xl/modular_loader.py | 10 + 11 files changed, 222 insertions(+), 232 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 60d78578940f..df88f9570f7b 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -284,12 +284,12 @@ def add(self, name, component, collection: Optional[str] = None): if comp == component: comp_name = self._id_to_name(comp_id) if comp_name == name: - logger.warning(f"component '{name}' already exists as '{comp_id}'") + logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'") component_id = comp_id break else: logger.warning( - f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" f"To remove a duplicate, call `components_manager.remove('')`." ) @@ -301,7 +301,7 @@ def add(self, name, component, collection: Optional[str] = None): if components_with_same_load_id: existing = ", ".join(components_with_same_load_id) logger.warning( - f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " f"To remove a duplicate, call `components_manager.remove('')`." ) @@ -315,12 +315,12 @@ def add(self, name, component, collection: Optional[str] = None): if component_id not in self.collections[collection]: comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) for comp_id in comp_ids_in_collection: - logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}") self.remove(comp_id) self.collections[collection].add(component_id) - logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}") else: - logger.info(f"Added component '{name}' as '{component_id}'") + logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) @@ -659,6 +659,10 @@ def get_model_info( return info def __repr__(self): + # Handle empty components case + if not self.components: + return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50 + # Helper to get simple name without UUID def get_simple_name(name): # Extract the base name by splitting on underscore and taking first part @@ -802,51 +806,6 @@ def format_device(component, info): return output - def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): - """ - Load components from a pretrained model and add them to the manager. - - Args: - pretrained_model_name_or_path (str): The path or identifier of the pretrained model - prefix (str, optional): Prefix to add to all component names loaded from this model. - If provided, components will be named as "{prefix}_{component_name}" - **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() - """ - subfolder = kwargs.pop("subfolder", None) - # YiYi TODO: extend AutoModel to support non-diffusers models - if subfolder: - from ..models import AutoModel - - component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) - component_name = f"{prefix}_{subfolder}" if prefix else subfolder - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - else: - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - def get_one( self, component_id: Optional[str] = None, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9981f87efcaf..657e83f29ebd 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -126,7 +126,7 @@ def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: input_names = self.input_kwargs.get(kwargs_type, []) return self.get_inputs(input_names) - def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ Get all intermediates with matching kwargs_type. @@ -325,7 +325,7 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): def init_pipeline( self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, - component_manager: Optional[ComponentsManager] = None, + components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, ): """ @@ -344,10 +344,10 @@ def init_pipeline( loader = loader_class( specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, - component_manager=component_manager, + components_manager=components_manager, collection=collection, ) - modular_pipeline = ModularPipeline(blocks=self, loader=loader) + modular_pipeline = ModularPipeline(blocks=deepcopy(self), loader=loader) return modular_pipeline @@ -374,17 +374,17 @@ def inputs(self) -> List[InputParam]: return [] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: """List of intermediate input parameters. Must be implemented by subclasses.""" return [] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] def _get_outputs(self): - return self.intermediates_outputs + return self.intermediate_outputs # YiYi TODO: is it too easy for user to unintentionally override these properties? # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @@ -403,9 +403,9 @@ def _get_required_inputs(self): def required_inputs(self) -> List[str]: return self._get_required_inputs() - def _get_required_intermediates_inputs(self): + def _get_required_intermediate_inputs(self): input_names = [] - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if input_param.required: input_names.append(input_param.name) return input_names @@ -413,8 +413,8 @@ def _get_required_intermediates_inputs(self): # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property - def required_intermediates_inputs(self) -> List[str]: - return self._get_required_intermediates_inputs() + def required_intermediate_inputs(self) -> List[str]: + return self._get_required_intermediate_inputs() def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -449,7 +449,7 @@ def __repr__(self): # Intermediates section intermediates_str = format_intermediates_short( - self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs + self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs ) intermediates = f"Intermediates:\n{intermediates_str}" @@ -459,7 +459,7 @@ def __repr__(self): def doc(self): return make_doc_string( self.inputs, - self.intermediates_inputs, + self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -492,7 +492,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.kwargs_type][k] = v # Check intermediates - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if input_param.name: value = state.get_intermediate(input_param.name) if input_param.required and value is None: @@ -503,9 +503,9 @@ def get_block_state(self, state: PipelineState) -> dict: # if kwargs_type is provided, get all intermediates with matching kwargs_type if input_param.kwargs_type not in data: data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + if intermediate_kwargs: + for k, v in intermediate_kwargs.items(): if v is not None: if k not in data: data[k] = v @@ -513,13 +513,13 @@ def get_block_state(self, state: PipelineState) -> dict: return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: + for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) # Only add if the value is different from what's in the state @@ -527,7 +527,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(input_param.name, param, input_param.kwargs_type) - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) # Only add if the value is different from what's in the state @@ -537,8 +537,8 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - for param_name, current_value in intermediates_kwargs.items(): + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediate_kwargs.items(): param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(param_name, param, input_param.kwargs_type) @@ -610,6 +610,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) +# YiYi TODO: change blocks attribute to a different name, so it is not confused with the blocks attribute in ModularPipeline class AutoPipelineBlocks(ModularPipelineBlocks): """ A class that automatically selects a block to run based on the inputs. @@ -692,15 +693,15 @@ def required_inputs(self) -> List[str]: # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property - def required_intermediates_inputs(self) -> List[str]: + def required_intermediate_inputs(self) -> List[str]: if None not in self.block_trigger_inputs: return [] first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + required_by_all = set(getattr(first_block, "required_intermediate_inputs", set())) # Intersect with required inputs from all other blocks for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) + block_required = set(getattr(block, "required_intermediate_inputs", set())) required_by_all.intersection_update(block_required) return list(required_by_all) @@ -719,20 +720,20 @@ def inputs(self) -> List[Tuple[str, Any]]: return combined_inputs @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + def intermediate_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediate_inputs) for name, block in self.blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required by all the blocks for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: + if input_param.name in self.required_intermediate_inputs: input_param.required = True else: input_param.required = False return combined_inputs @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + def intermediate_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs @@ -885,7 +886,7 @@ def __repr__(self): def doc(self): return make_doc_string( self.inputs, - self.intermediates_inputs, + self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -975,12 +976,12 @@ def required_inputs(self) -> List[str]: # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: + def required_intermediate_inputs(self) -> List[str]: + required_intermediate_inputs = [] + for input_param in self.intermediate_inputs: if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs + required_intermediate_inputs.append(input_param.name) + return required_intermediate_inputs # YiYi TODO: add test for this @property @@ -999,10 +1000,10 @@ def get_inputs(self): return combined_inputs @property - def intermediates_inputs(self) -> List[str]: - return self.get_intermediates_inputs() + def intermediate_inputs(self) -> List[str]: + return self.get_intermediate_inputs() - def get_intermediates_inputs(self): + def get_intermediate_inputs(self): inputs = [] outputs = set() added_inputs = set() @@ -1010,7 +1011,7 @@ def get_intermediates_inputs(self): # Go through all blocks in order for block in self.blocks.values(): # Add inputs that aren't in outputs yet - for inp in block.intermediates_inputs: + for inp in block.intermediate_inputs: if inp.name not in outputs and inp.name not in added_inputs: inputs.append(inp) added_inputs.add(inp.name) @@ -1022,27 +1023,27 @@ def get_intermediates_inputs(self): if should_add_outputs: # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) return inputs @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: named_outputs = [] for name, block in self.blocks.items(): - inp_names = {inp.name for inp in block.intermediates_inputs} - # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) - # filter out them here so they do not end up as intermediates_outputs + inp_names = {inp.name for inp in block.intermediate_inputs} + # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediate_outputs if name not in inp_names: - named_outputs.append((name, block.intermediates_outputs)) + named_outputs.append((name, block.intermediate_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - # return next(reversed(self.blocks.values())).intermediates_outputs - return self.intermediates_outputs + # return next(reversed(self.blocks.values())).intermediate_outputs + return self.intermediate_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -1248,7 +1249,7 @@ def __repr__(self): def doc(self): return make_doc_string( self.inputs, - self.intermediates_inputs, + self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -1287,12 +1288,12 @@ def loop_inputs(self) -> List[InputParam]: return [] @property - def loop_intermediates_inputs(self) -> List[InputParam]: + def loop_intermediate_inputs(self) -> List[InputParam]: """List of intermediate input parameters. Must be implemented by subclasses.""" return [] @property - def loop_intermediates_outputs(self) -> List[OutputParam]: + def loop_intermediate_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] @@ -1305,9 +1306,9 @@ def loop_required_inputs(self) -> List[str]: return input_names @property - def loop_required_intermediates_inputs(self) -> List[str]: + def loop_required_intermediate_inputs(self) -> List[str]: input_names = [] - for input_param in self.loop_intermediates_inputs: + for input_param in self.loop_intermediate_inputs: if input_param.required: input_names.append(input_param.name) return input_names @@ -1356,25 +1357,25 @@ def get_inputs(self): def inputs(self): return self.get_inputs() - # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + # modified from SequentialPipelineBlocks to include loop_intermediate_inputs @property - def intermediates_inputs(self): - intermediates = self.get_intermediates_inputs() + def intermediate_inputs(self): + intermediates = self.get_intermediate_inputs() intermediate_names = [input.name for input in intermediates] - for loop_intermediate_input in self.loop_intermediates_inputs: + for loop_intermediate_input in self.loop_intermediate_inputs: if loop_intermediate_input.name not in intermediate_names: intermediates.append(loop_intermediate_input) return intermediates # modified from SequentialPipelineBlocks - def get_intermediates_inputs(self): + def get_intermediate_inputs(self): inputs = [] outputs = set() # Go through all blocks in order for block in self.blocks.values(): # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -1383,8 +1384,8 @@ def get_intermediates_inputs(self): if should_add_outputs: # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) return inputs # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block @@ -1407,23 +1408,23 @@ def required_inputs(self) -> List[str]: # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: + def required_intermediate_inputs(self) -> List[str]: + required_intermediate_inputs = [] + for input_param in self.intermediate_inputs: if input_param.required: - required_intermediates_inputs.append(input_param.name) - for input_param in self.loop_intermediates_inputs: + required_intermediate_inputs.append(input_param.name) + for input_param in self.loop_intermediate_inputs: if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs + required_intermediate_inputs.append(input_param.name) + return required_intermediate_inputs # YiYi TODO: this need to be thought about more - # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + # modified from SequentialPipelineBlocks to include loop_intermediate_outputs @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + def intermediate_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) - for output in self.loop_intermediates_outputs: + for output in self.loop_intermediate_outputs: if output.name not in {output.name for output in combined_outputs}: combined_outputs.append(output) return combined_outputs @@ -1431,7 +1432,7 @@ def intermediates_outputs(self) -> List[str]: # YiYi TODO: this need to be thought about more @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs + return next(reversed(self.blocks.values())).intermediate_outputs def __init__(self): blocks = InsertableOrderedDict() @@ -1497,7 +1498,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.kwargs_type][k] = v # Check intermediates - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if input_param.name: value = state.get_intermediate(input_param.name) if input_param.required and value is None: @@ -1508,9 +1509,9 @@ def get_block_state(self, state: PipelineState) -> dict: # if kwargs_type is provided, get all intermediates with matching kwargs_type if input_param.kwargs_type not in data: data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + if intermediate_kwargs: + for k, v in intermediate_kwargs.items(): if v is not None: if k not in data: data[k] = v @@ -1518,13 +1519,13 @@ def get_block_state(self, state: PipelineState) -> dict: return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: + for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) - for input_param in self.intermediates_inputs: + for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) # Only add if the value is different from what's in the state @@ -1534,8 +1535,8 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - for param_name, current_value in intermediates_kwargs.items(): + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediate_kwargs.items(): if not hasattr(block_state, param_name): continue param = getattr(block_state, param_name) @@ -1546,7 +1547,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): def doc(self): return make_doc_string( self.inputs, - self.intermediates_inputs, + self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -1660,7 +1661,7 @@ def register_components(self, **kwargs): - non from_pretrained components are created during __init__ and registered as the object itself - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) - - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"]) Args: **kwargs: Keyword arguments where keys are component names and values are component objects. @@ -1710,8 +1711,8 @@ def register_components(self, **kwargs): if not is_registered: self.register_to_config(**register_dict) setattr(self, name, module) - if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) + if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) continue current_module = getattr(self, name, None) @@ -1745,22 +1746,22 @@ def register_components(self, **kwargs): # finally set models setattr(self, name, module) # add to component manager if one is attached - if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) + if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name def __init__( self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, - component_manager: Optional[ComponentsManager] = None, + components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs, ): """ Initialize the loader with a list of component specs and config specs. """ - self._component_manager = component_manager + self._components_manager = components_manager self._collection = collection self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} @@ -1848,6 +1849,10 @@ def dtype(self) -> torch.dtype: return module.dtype return torch.float32 + + @property + def component_names(self) -> List[str]: + return list(self.components.keys()) @property def components(self) -> Dict[str, Any]: @@ -1958,12 +1963,12 @@ def update(self, **kwargs): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, component_names: Optional[List[str]] = None, **kwargs): + def load(self, names: Optional[List[str]] = None, **kwargs): """ - Load selectedcomponents from specs. + Load selected components from specs. Args: - component_names: List of component names to load + names: List of component names to load **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} @@ -1971,19 +1976,19 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): `variant`, `revision`, etc. """ # if not specific name, load all the components with default_creation_method == "from_pretrained" - if component_names is None: - component_names = [ + if names is None: + names = [ name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained" ] - elif not isinstance(component_names, list): - component_names = [component_names] + elif not isinstance(names, list): + names = [names] - components_to_load = {name for name in component_names if name in self._component_specs} - unknown_component_names = {name for name in component_names if name not in self._component_specs} - if len(unknown_component_names) > 0: - logger.warning(f"Unknown components will be ignored: {unknown_component_names}") + components_to_load = {name for name in names if name in self._component_specs} + unknown_names = {name for name in names if name not in self._component_specs} + if len(unknown_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_names}") components_to_register = {} for name in components_to_load: @@ -2240,7 +2245,7 @@ def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, - component_manager: Optional[ComponentsManager] = None, + components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs, ): @@ -2261,7 +2266,7 @@ def from_pretrained( elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) + return cls(component_specs + config_specs, components_manager=components_manager, collection=collection) @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: @@ -2370,20 +2375,20 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs - intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs] + intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] for expected_input_param in self.blocks.inputs: name = expected_input_param.name default = expected_input_param.default kwargs_type = expected_input_param.kwargs_type if name in passed_kwargs: - if name not in intermediates_inputs: + if name not in intermediate_inputs: state.add_input(name, passed_kwargs.pop(name), kwargs_type) else: state.add_input(name, passed_kwargs[name], kwargs_type) elif name not in state.inputs: state.add_input(name, default, kwargs_type) - for expected_intermediate_param in self.blocks.intermediates_inputs: + for expected_intermediate_param in self.blocks.intermediate_inputs: name = expected_intermediate_param.name kwargs_type = expected_intermediate_param.kwargs_type if name in passed_kwargs: @@ -2412,8 +2417,8 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = else: raise ValueError(f"Output '{output}' is not a valid output type") - def load_components(self, component_names: Optional[List[str]] = None, **kwargs): - self.loader.load(component_names=component_names, **kwargs) + def load_components(self, names: Optional[List[str]] = None, **kwargs): + self.loader.load(names=names, **kwargs) def update_components(self, **kwargs): self.loader.update(**kwargs) @@ -2424,7 +2429,7 @@ def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, - component_manager: Optional[ComponentsManager] = None, + components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs, ): @@ -2432,7 +2437,7 @@ def from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) pipeline = blocks.init_pipeline( - pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs + pretrained_model_name_or_path, components_manager=components_manager, collection=collection, **kwargs ) return pipeline diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 015bec89837d..86c017fd6d89 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -49,7 +49,13 @@ def __repr__(self): items = [] for i, (key, value) in enumerate(self.items()): - items.append(f"{i}: ({repr(key)}, {repr(value)})") + if isinstance(value, type): + # For classes, show class name and + obj_repr = f"" + else: + # For objects (instances) and other types, show class name and module + obj_repr = f"" + items.append(f"{i}: ({repr(key)}, {obj_repr})") return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])" @@ -260,11 +266,11 @@ class ConfigSpec: description: Optional[str] = None -# YiYi Notes: both inputs and intermediates_inputs are InputParam objects -# however some fields are not relevant for intermediates_inputs +# YiYi Notes: both inputs and intermediate_inputs are InputParam objects +# however some fields are not relevant for intermediate_inputs # e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed -# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs -# -> should we use different class for inputs and intermediates_inputs? +# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs +# -> should we use different class for inputs and intermediate_inputs? @dataclass class InputParam: """Specification for an input parameter.""" @@ -324,14 +330,14 @@ def format_inputs_short(inputs): return inputs_str -def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): +def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs): """ Formats intermediate inputs and outputs of a block into a string representation. Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters + intermediate_inputs: List of intermediate input parameters + required_intermediate_inputs: List of required intermediate input names + intermediate_outputs: List of intermediate output parameters Returns: str: Formatted string like: @@ -342,8 +348,8 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu """ # Handle inputs input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: + for inp in intermediate_inputs: + if inp.name in required_intermediate_inputs: input_parts.append(f"Required({inp.name})") else: if inp.name is None and inp.kwargs_type is not None: @@ -353,11 +359,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu input_parts.append(inp_name) # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} + inputs_set = {inp.name for inp in intermediate_inputs} modified_parts = [] new_output_parts = [] - for out in intermediates_outputs: + for out in intermediate_outputs: if out.name in inputs_set: modified_parts.append(out.name) else: @@ -575,7 +581,7 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines def make_doc_string( inputs, - intermediates_inputs, + intermediate_inputs, outputs, description="", class_name=None, @@ -587,7 +593,7 @@ def make_doc_string( Args: inputs: List of input parameters - intermediates_inputs: List of intermediate input parameters + intermediate_inputs: List of intermediate input parameters outputs: List of output parameters description (str, *optional*): Description of the block class_name (str, *optional*): Name of the class to include in the documentation @@ -621,7 +627,7 @@ def make_doc_string( output += configs_str + "\n\n" # Add inputs section - output += format_input_params(inputs + intermediates_inputs, indent_level=2) + output += format_input_params(inputs + intermediate_inputs, indent_level=2) # Add outputs section output += "\n\n" diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 7c3c33d3f648..fe7cede459d5 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -382,7 +382,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} - inputs = self.blocks.inputs + self.blocks.intermediates_inputs + inputs = self.blocks.inputs + self.blocks.intermediate_inputs for inp in inputs: param = kwargs.pop(inp.name, None) if param: @@ -455,9 +455,9 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): output_params = {} if isinstance(self.blocks, SequentialPipelineBlocks): last_block_name = list(self.blocks.blocks.keys())[-1] - outputs = self.blocks.blocks[last_block_name].intermediates_outputs + outputs = self.blocks.blocks[last_block_name].intermediate_outputs else: - outputs = self.blocks.intermediates_outputs + outputs = self.blocks.intermediate_outputs for out in outputs: param = kwargs.pop(out.name, None) @@ -495,9 +495,9 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): } self.register_to_config(**register_dict) - def setup(self, components, collection=None): - self.blocks.setup_loader(component_manager=components, collection=collection) - self._components_manager = components + def setup(self, components_manager, collection=None): + self.blocks.setup_loader(components_manager=components_manager, collection=collection) + self._components_manager = components_manager @property def mellon_config(self): diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index fd767609f6f1..9adb0527958c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -28,12 +28,13 @@ "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "IP_ADAPTER_BLOCKS", - "SDXL_SUPPORTED_BLOCKS", + "ALL_BLOCKS", "TEXT2IMAGE_BLOCKS", "StableDiffusionXLAutoBlocks", "StableDiffusionXLAutoDecodeStep", "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", + "StableDiffusionXLAutoControlnetStep", ] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] @@ -53,12 +54,13 @@ IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, IP_ADAPTER_BLOCKS, - SDXL_SUPPORTED_BLOCKS, + ALL_BLOCKS, TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoBlocks, StableDiffusionXLAutoDecodeStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLAutoControlnetStep, ) from .modular_loader import StableDiffusionXLModularLoader else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 33ebf39e5478..174a2bf58a15 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -215,7 +215,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "prompt_embeds", @@ -251,7 +251,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: return [ OutputParam( "batch_size", @@ -423,7 +423,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "batch_size", @@ -434,7 +434,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: return [ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam( @@ -565,7 +565,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam( @@ -642,7 +642,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam("generator"), InputParam( @@ -678,7 +678,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: return [ OutputParam( "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" @@ -928,7 +928,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam( @@ -953,7 +953,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" @@ -1009,7 +1009,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam( @@ -1022,7 +1022,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" @@ -1124,7 +1124,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam( "latents", @@ -1147,7 +1147,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "add_time_ids", @@ -1328,7 +1328,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam( "latents", @@ -1351,7 +1351,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "add_time_ids", @@ -1510,7 +1510,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -1538,7 +1538,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), OutputParam( @@ -1730,7 +1730,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam( "latents", @@ -1764,7 +1764,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), OutputParam( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 12e3e92d292a..d15fa6f3722b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -70,7 +70,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: return [ OutputParam( "images", @@ -170,30 +170,28 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("image", required=True), - InputParam("mask_image", required=True), + InputParam("image"), + InputParam("mask_image"), InputParam("padding_mask_crop"), ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "images", - required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step", ), InputParam( "crops_coords", - required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", ), ] @property - def intermediates_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[str]: return [ OutputParam( "images", diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 63c67f436799..794bfa297584 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -55,7 +55,7 @@ def description(self) -> str: ) @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -91,7 +91,7 @@ def description(self) -> str: ) @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -174,7 +174,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "num_inference_steps", @@ -280,7 +280,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "controlnet_cond", @@ -473,13 +473,13 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam("generator"), ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] # YiYi TODO: move this out of here @@ -545,7 +545,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam("generator"), InputParam( @@ -572,7 +572,7 @@ def intermediates_inputs(self) -> List[str]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] @staticmethod @@ -660,7 +660,7 @@ def loop_expected_components(self) -> List[ComponentSpec]: ] @property - def loop_intermediates_inputs(self) -> List[InputParam]: + def loop_intermediate_inputs(self) -> List[InputParam]: return [ InputParam( "timesteps", diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 3561cbf70a82..ff38962c0256 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -63,8 +63,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): @property def description(self) -> str: return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + "IP Adapter step that prepares ip adapter image embeddings.\n" + "Note that this step only prepares the embeddings - in order for it to work correctly, " + "you need to load ip adapter weights into unet via ModularPipeline.loader.\n" + "e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().\n" + "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) @@ -99,7 +102,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), OutputParam( @@ -251,7 +254,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "prompt_embeds", @@ -602,7 +605,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), @@ -614,7 +617,7 @@ def intermediates_inputs(self) -> List[InputParam]: ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "image_latents", @@ -727,14 +730,14 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), InputParam("generator"), ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" @@ -844,6 +847,11 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype block_state.device = components._execution_device + if block_state.height is None: + block_state.height = components.default_height + if block_state.width is None: + block_state.width = components.default_width + if block_state.padding_mask_crop is not None: block_state.crops_coords = components.mask_processor.get_crop_region( block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py index de34c50d43be..0ad865544ee5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -68,7 +68,7 @@ def description(self): ) -# optional ip-adapter (run before before_denoise) +# optional ip-adapter (run before input step) class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLIPAdapterStep] block_names = ["ip_adapter"] @@ -76,7 +76,9 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): @property def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." + return ( + "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" + ) # before_denoise: text2img @@ -370,7 +372,7 @@ def description(self): ) -SDXL_SUPPORTED_BLOCKS = { +ALL_BLOCKS = { "text2img": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "inpaint": INPAINT_BLOCKS, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 59a723335965..34222444dae3 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -44,6 +44,16 @@ class StableDiffusionXLModularLoader( StableDiffusionXLLoraLoaderMixin, ModularIPAdapterMixin, ): + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property def default_sample_size(self): default_sample_size = 128 From 655512e2cf5ff7eb9c0daa7bf1ac7e970efa62f2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 08:35:50 +0200 Subject: [PATCH 110/170] components manager: change get -> search_models; add get_ids, get_components_by_ids, get_components_by_names --- .../modular_pipelines/components_manager.py | 331 ++++++++++-------- 1 file changed, 193 insertions(+), 138 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index df88f9570f7b..2e6c288ad9d9 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -232,6 +232,8 @@ def search_best_candidate(module_sizes, min_memory_offload): class ComponentsManager: + _available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"] + def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added @@ -239,9 +241,10 @@ def __init__(self): self.model_hooks = None self._auto_offload_enabled = False - def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): + def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None): """ - Lookup component_ids by name, collection, or load_id. + Lookup component_ids by name, collection, or load_id. Does not support pattern matching. + Returns a set of component_ids """ if components is None: components = self.components @@ -351,15 +354,16 @@ def remove(self, component_id: str = None): if torch.cuda.is_available(): torch.cuda.empty_cache() - def get( + # YiYi TODO: rename to search_components for now, may remove this method + def search_components( self, - names: Union[str, List[str]] = None, + names: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False, + return_dict_with_names: bool = True, ): """ - Select components by name with simple pattern matching. + Search components by name with simple pattern matching. Optionally filter by collection or load_id. Args: names: Component name(s) or pattern(s) @@ -375,34 +379,48 @@ def get( - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" collection: Optional collection to filter by load_id: Optional load_id to filter by - as_name_component_tuples: If True, returns a list of (name, component) tuples using base names - instead of a dictionary with component IDs as keys + return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found + If False, returns a dictionary with component IDs as keys Returns: - Dictionary mapping component IDs to components or list of (base_name, component) tuples if - as_name_component_tuples=True + Dictionary mapping component names to components if return_dict_with_names=True, + or a dictionary mapping component IDs to components if return_dict_with_names=False """ + # select components based on collection and load_id filters selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} + + def get_return_dict(components, return_dict_with_names): + """ + Create a dictionary mapping component names to components if return_dict_with_names=True, + or a dictionary mapping component IDs to components if return_dict_with_names=False, + throw an error if duplicate component names are found when return_dict_with_names=True + """ + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + dict_to_return[comp_name] = comp + return dict_to_return + else: + return components - # Helper to extract base name from component_id - def get_base_name(component_id): - parts = component_id.split("_") - # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]: - return "_".join(parts[:-1]) - return component_id + # if no names are provided, return the filtered components as it is if names is None: - if as_name_component_tuples: - return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] - else: - return components + return get_return_dict(components, return_dict_with_names) + + # if names is not a string, raise an error + elif not isinstance(names, str): + raise ValueError(f"Invalid type for `names: {type(names)}, only support string") - # Create mapping from component_id to base_name for all components - base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + # Create mapping from component_id to base_name for components to be used for pattern matching + base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()} + # Helper function to check if a component matches a pattern based on its base name def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. @@ -432,113 +450,95 @@ def matches_pattern(component_id, pattern, exact_match=False): else: return pattern == base_name - if isinstance(names, str): - # Check if this is a "not" pattern - is_not_pattern = names.startswith("!") - if is_not_pattern: - names = names[1:] # Remove the ! prefix + # Check if this is a "not" pattern + is_not_pattern = names.startswith("!") + if is_not_pattern: + names = names[1:] # Remove the ! prefix - # Handle OR patterns (containing |) - if "|" in names: - terms = names.split("|") - matches = {} + # Handle OR patterns (containing |) + if "|" in names: + terms = names.split("|") + matches = {} - for comp_id, comp in components.items(): - # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) - - # Check if any of the terms match this component - should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - - # Flip the decision if this is a NOT pattern - if is_not_pattern: - should_include = not should_include - - if should_include: - matches[comp_id] = comp - - log_msg = "NOT " if is_not_pattern else "" - match_type = "exactly matching" if exact_match else "matching any of patterns" - logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - - # Try exact match with a base name - elif any(names == base_name for base_name in base_names.values()): - # Find all components with this base name - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (base_names[comp_id] == names) != is_not_pattern - } + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) - if is_not_pattern: - logger.info( - f"Getting all components except those with base name '{names}': {list(matches.keys())}" - ) - else: - logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - # Prefix match (ends with *) - elif names.endswith("*"): - prefix = names[:-1] - matches = { - comp_id: comp - for comp_id, comp in components.items() - if base_names[comp_id].startswith(prefix) != is_not_pattern - } + # Flip the decision if this is a NOT pattern if is_not_pattern: - logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") - else: - logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + should_include = not should_include - # Contains match (starts with *) - elif names.startswith("*"): - search = names[1:-1] if names.endswith("*") else names[1:] - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (search in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - - # Substring match (no wildcards, but not an exact component name) - elif any(names in base_name for base_name in base_names.values()): - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (names in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + if should_include: + matches[comp_id] = comp - else: - raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - if not matches: - raise ValueError(f"No components found matching pattern '{names}'") + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + if is_not_pattern: + logger.info( + f"Getting all components except those with base name '{names}': {list(matches.keys())}" + ) else: - return matches - - elif isinstance(names, list): - results = {} - for name in names: - result = self.get(name, collection, load_id, as_name_component_tuples=False) - results.update(result) - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # Prefix match (ends with *) + elif names.endswith("*"): + prefix = names[:-1] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") + else: + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + + # Contains match (starts with *) + elif names.startswith("*"): + search = names[1:-1] if names.endswith("*") else names[1:] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") else: - return results + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") else: - raise ValueError(f"Invalid type for names: {type(names)}") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + + return get_return_dict(matches, return_dict_with_names) def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): if not is_accelerate_available(): @@ -582,16 +582,18 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - # YiYi TODO: add quantization info + # YiYi TODO: (1) add quantization info def get_model_info( - self, component_id: str, fields: Optional[Union[str, List[str]]] = None + self, + component_id: str, + fields: Optional[Union[str, List[str]]] = None, ) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. - If None, returns all fields. + If None, uses the available_info_fields setting. Returns: Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a @@ -601,6 +603,14 @@ def get_model_info( raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] + + # Validate fields if specified + if fields is not None: + if isinstance(fields, str): + fields = [fields] + for field in fields: + if field not in self._available_info_fields: + raise ValueError(f"Field '{field}' not found in available_info_fields") # Build complete info dict first info = { @@ -649,15 +659,11 @@ def get_model_info( # If fields specified, filter info if fields is not None: - if isinstance(fields, str): - # Single field requested, return just that value - return {fields: info.get(fields)} - else: - # List of fields requested, return dict with just those fields - return {k: v for k, v in info.items() if k in fields} - - return info - + return {k: v for k, v in info.items() if k in fields} + else: + return info + + # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table def __repr__(self): # Handle empty components case if not self.components: @@ -814,9 +820,14 @@ def get_one( load_id: Optional[str] = None, ) -> Any: """ - Get a single component by name. Raises an error if multiple components match or none are found. + Get a single component by either: + (1) searching name (pattern matching), collection, or load_id. + (2) passing in a component_id + Raises an error if multiple components match or none are found. + support pattern matching for name Args: + component_id: Optional component ID to get name: Component name or pattern collection: Optional collection to filter by load_id: Optional load_id to filter by @@ -828,15 +839,16 @@ def get_one( ValueError: If no components match or multiple components match """ - # if component_id is provided, return the component if component_id is not None and (name is not None or collection is not None or load_id is not None): - raise ValueError(" if component_id is provided, name, collection, and load_id must be None") - elif component_id is not None: + raise ValueError("If searching by component_id, do not pass name, collection, or load_id") + + # search by component_id + if component_id is not None: if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") return self.components[component_id] - - results = self.get(name, collection, load_id) + # search with name/collection/load_id + results = self.search_components(name, collection, load_id) if not results: raise ValueError(f"No components found matching '{name}'") @@ -845,20 +857,63 @@ def get_one( raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") return next(iter(results.values())) + + + def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): + """ + Get component IDs by a list of names, optionally filtered by collection. + """ + ids = set() + for name in names: + ids.update(self._lookup_ids(name=name, collection=collection)) + return list(ids) + + def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): + """ + Get components by a list of IDs. + """ + components = {id: self.components[id] for id in ids} + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + dict_to_return[comp_name] = comp + return dict_to_return + else: + return components + + def get_components_by_names(self, names: List[str], collection: Optional[str] = None): + """ + Get components by a list of names, optionally filtered by collection. + """ + ids = self.get_ids(names, collection) + return self.get_components_by_ids(ids) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. - For a dictionary with dot-separated keys like: { + This function is particularly useful for IP-Adapter attention processor patterns, where multiple + attention layers may share the same scale value. It groups dot-separated keys by their values + and finds the shortest common prefix for each group. + + For example, given a dictionary with IP-Adapter attention processor patterns like: + { 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor': [0.3], } - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { - 'down_blocks': [0.6], 'up_blocks': [0.3] + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: + { + 'down_blocks.1.attentions.1.transformer_blocks': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks': [0.3] } + + This helps identify which attention layers share the same IP-Adapter scale values. """ # First group by values - convert lists to tuples to make them hashable value_to_keys = {} From 885a596696b2d31d7bdfac45b28c96161762ac81 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 08:52:43 +0200 Subject: [PATCH 111/170] blocks -> sub_blocks; will not by default load all; add load_default_components method on modular_pipeline --- .../modular_pipelines/modular_pipeline.py | 168 ++++++++++-------- 1 file changed, 92 insertions(+), 76 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 657e83f29ebd..d1f8b2c3d074 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -626,10 +626,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks): block_trigger_inputs = [] def __init__(self): - blocks = InsertableOrderedDict() + sub_blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): raise ValueError( f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." @@ -646,13 +646,13 @@ def __init__(self): ) # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys())) + self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs)) @property def model_name(self): - return next(iter(self.blocks.values())).model_name + return next(iter(self.sub_blocks.values())).model_name @property def description(self): @@ -661,7 +661,7 @@ def description(self): @property def expected_components(self): expected_components = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for component in block.expected_components: if component not in expected_components: expected_components.append(component) @@ -670,7 +670,7 @@ def expected_components(self): @property def expected_configs(self): expected_configs = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for config in block.expected_configs: if config not in expected_configs: expected_configs.append(config) @@ -680,11 +680,11 @@ def expected_configs(self): def required_inputs(self) -> List[str]: if None not in self.block_trigger_inputs: return [] - first_block = next(iter(self.blocks.values())) + first_block = next(iter(self.sub_blocks.values())) required_by_all = set(getattr(first_block, "required_inputs", set())) # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: + for block in list(self.sub_blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_all.intersection_update(block_required) @@ -696,11 +696,11 @@ def required_inputs(self) -> List[str]: def required_intermediate_inputs(self) -> List[str]: if None not in self.block_trigger_inputs: return [] - first_block = next(iter(self.blocks.values())) + first_block = next(iter(self.sub_blocks.values())) required_by_all = set(getattr(first_block, "required_intermediate_inputs", set())) # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: + for block in list(self.sub_blocks.values())[1:]: block_required = set(getattr(block, "required_intermediate_inputs", set())) required_by_all.intersection_update(block_required) @@ -709,7 +709,7 @@ def required_intermediate_inputs(self) -> List[str]: # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required by all the blocks for input_param in combined_inputs: @@ -721,7 +721,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediate_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediate_inputs) for name, block in self.blocks.items()] + named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required by all the blocks for input_param in combined_inputs: @@ -733,13 +733,13 @@ def intermediate_inputs(self) -> List[str]: @property def intermediate_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()] + named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs @property def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs @@ -788,15 +788,15 @@ def fn_recursive_get_trigger(blocks): # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - # If block has blocks, recursively check them - if hasattr(block, "blocks"): - nested_triggers = fn_recursive_get_trigger(block.blocks) + # If block has sub_blocks, recursively check them + if hasattr(block, "sub_blocks"): + nested_triggers = fn_recursive_get_trigger(block.sub_blocks) trigger_values.update(nested_triggers) return trigger_values trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks)) return trigger_inputs @@ -841,7 +841,7 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): + for i, (name, block) in enumerate(self.sub_blocks.items()): # Get trigger input for this block trigger = None if hasattr(self, "block_to_trigger_map"): @@ -909,12 +909,12 @@ def description(self): @property def model_name(self): - return next(iter(self.blocks.values())).model_name + return next(iter(self.sub_blocks.values())).model_name @property def expected_components(self): expected_components = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for component in block.expected_components: if component not in expected_components: expected_components.append(component) @@ -923,7 +923,7 @@ def expected_components(self): @property def expected_configs(self): expected_configs = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for config in block.expected_configs: if config not in expected_configs: expected_configs.append(config) @@ -942,32 +942,32 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance = cls() # Create instances if classes are provided - blocks = InsertableOrderedDict() + sub_blocks = InsertableOrderedDict() for name, block in blocks_dict.items(): if inspect.isclass(block): - blocks[name] = block() + sub_blocks[name] = block() else: - blocks[name] = block + sub_blocks[name] = block - instance.block_classes = [block.__class__ for block in blocks.values()] - instance.block_names = list(blocks.keys()) - instance.blocks = blocks + instance.block_classes = [block.__class__ for block in sub_blocks.values()] + instance.block_names = list(sub_blocks.keys()) + instance.sub_blocks = sub_blocks return instance def __init__(self): - blocks = InsertableOrderedDict() + sub_blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) + first_block = next(iter(self.sub_blocks.values())) required_by_any = set(getattr(first_block, "required_inputs", set())) # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: + for block in list(self.sub_blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) @@ -989,7 +989,7 @@ def inputs(self) -> List[Tuple[str, Any]]: return self.get_inputs() def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks for input_param in combined_inputs: @@ -1009,7 +1009,7 @@ def get_intermediate_inputs(self): added_inputs = set() # Go through all blocks in order - for block in self.blocks.values(): + for block in self.sub_blocks.values(): # Add inputs that aren't in outputs yet for inp in block.intermediate_inputs: if inp.name not in outputs and inp.name not in added_inputs: @@ -1030,7 +1030,7 @@ def get_intermediate_inputs(self): @property def intermediate_outputs(self) -> List[str]: named_outputs = [] - for name, block in self.blocks.items(): + for name, block in self.sub_blocks.items(): inp_names = {inp.name for inp in block.intermediate_inputs} # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) # filter out them here so they do not end up as intermediate_outputs @@ -1042,12 +1042,12 @@ def intermediate_outputs(self) -> List[str]: # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - # return next(reversed(self.blocks.values())).intermediate_outputs + # return next(reversed(self.sub_blocks.values())).intermediate_outputs return self.intermediate_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): + for block_name, block in self.sub_blocks.items(): try: pipeline, state = block(pipeline, state) except Exception as e: @@ -1076,14 +1076,14 @@ def fn_recursive_get_trigger(blocks): # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - # If block has blocks, recursively check them - if hasattr(block, "blocks"): - nested_triggers = fn_recursive_get_trigger(block.blocks) + # If block has sub_blocks, recursively check them + if hasattr(block, "sub_blocks"): + nested_triggers = fn_recursive_get_trigger(block.sub_blocks) trigger_values.update(nested_triggers) return trigger_values - return fn_recursive_get_trigger(self.blocks) + return fn_recursive_get_trigger(self.sub_blocks) @property def trigger_inputs(self): @@ -1098,9 +1098,9 @@ def fn_recursive_traverse(block, block_name, active_triggers): # sequential(include loopsequential) or PipelineBlock if not hasattr(block, "block_trigger_inputs"): - if hasattr(block, "blocks"): + if hasattr(block, "sub_blocks"): # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.blocks.items(): + for sub_block_name, sub_block in block.sub_blocks.items(): blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} @@ -1128,7 +1128,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): if this_block is not None: # sequential/auto (keep traversing) - if hasattr(this_block, "blocks"): + if hasattr(this_block, "sub_blocks"): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock @@ -1141,7 +1141,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): return result_blocks all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): + for block_name, block in self.sub_blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) all_blocks.update(blocks_to_update) return all_blocks @@ -1204,7 +1204,7 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): + for i, (name, block) in enumerate(self.sub_blocks.items()): # Get trigger input for this block trigger = None if hasattr(self, "block_to_trigger_map"): @@ -1317,7 +1317,7 @@ def loop_required_intermediate_inputs(self) -> List[str]: @property def expected_components(self): expected_components = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for component in block.expected_components: if component not in expected_components: expected_components.append(component) @@ -1330,7 +1330,7 @@ def expected_components(self): @property def expected_configs(self): expected_configs = [] - for block in self.blocks.values(): + for block in self.sub_blocks.values(): for config in block.expected_configs: if config not in expected_configs: expected_configs.append(config) @@ -1341,7 +1341,7 @@ def expected_configs(self): # modified from SequentialPipelineBlocks to include loop_inputs def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] named_inputs.append(("loop", self.loop_inputs)) combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks @@ -1373,7 +1373,7 @@ def get_intermediate_inputs(self): outputs = set() # Go through all blocks in order - for block in self.blocks.values(): + for block in self.sub_blocks.values(): # Add inputs that aren't in outputs yet inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs) @@ -1392,14 +1392,14 @@ def get_intermediate_inputs(self): @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) + first_block = next(iter(self.sub_blocks.values())) required_by_any = set(getattr(first_block, "required_inputs", set())) required_by_loop = set(getattr(self, "loop_required_inputs", set())) required_by_any.update(required_by_loop) # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: + for block in list(self.sub_blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) @@ -1422,7 +1422,7 @@ def required_intermediate_inputs(self) -> List[str]: # modified from SequentialPipelineBlocks to include loop_intermediate_outputs @property def intermediate_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()] + named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] combined_outputs = combine_outputs(*named_outputs) for output in self.loop_intermediate_outputs: if output.name not in {output.name for output in combined_outputs}: @@ -1432,13 +1432,13 @@ def intermediate_outputs(self) -> List[str]: # YiYi TODO: this need to be thought about more @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediate_outputs + return next(reversed(self.sub_blocks.values())).intermediate_outputs def __init__(self): - blocks = InsertableOrderedDict() + sub_blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": @@ -1454,11 +1454,11 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelin instance = cls() instance.block_classes = [block.__class__ for block in blocks_dict.values()] instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict + instance.sub_blocks = blocks_dict return instance def loop_step(self, components, state: PipelineState, **kwargs): - for block_name, block in self.blocks.items(): + for block_name, block in self.sub_blocks.items(): try: components, state = block(components, state, **kwargs) except Exception as e: @@ -1585,7 +1585,7 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): + for i, (name, block) in enumerate(self.sub_blocks.items()): # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" @@ -1850,9 +1850,22 @@ def dtype(self) -> torch.dtype: return torch.float32 + @property + def null_component_names(self) -> List[str]: + return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None] + @property def component_names(self) -> List[str]: return list(self.components.keys()) + + @property + def pretrained_component_names(self) -> List[str]: + return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] + + @property + def config_component_names(self) -> List[str]: + return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"] + @property def components(self) -> Dict[str, Any]: @@ -1963,27 +1976,23 @@ def update(self, **kwargs): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, names: Optional[List[str]] = None, **kwargs): + def load(self, names: Union[List[str], str], **kwargs): """ Load selected components from specs. Args: - names: List of component names to load + names: List of component names to load; by default will not load any components **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ - # if not specific name, load all the components with default_creation_method == "from_pretrained" - if names is None: - names = [ - name - for name in self._component_specs.keys() - if self._component_specs[name].default_creation_method == "from_pretrained" - ] - elif not isinstance(names, list): + # if not pass any names, will not load any components + if isinstance(names, str): names = [names] + elif not isinstance(names, list): + raise ValueError(f"Invalid type for names: {type(names)}") components_to_load = {name for name in names if name in self._component_specs} unknown_names = {name for name in names if name not in self._component_specs} @@ -2308,7 +2317,10 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: else: lib_name = None cls_name = None - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + if component_spec.default_creation_method == "from_pretrained": + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + else: + load_spec_dict = {} return { "type_hint": (lib_name, cls_name), **load_spec_dict, @@ -2417,7 +2429,11 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = else: raise ValueError(f"Output '{output}' is not a valid output type") - def load_components(self, names: Optional[List[str]] = None, **kwargs): + def load_default_components(self, **kwargs): + names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"] + self.loader.load(names=names, **kwargs) + + def load_components(self, names: Union[List[str], str], **kwargs): self.loader.load(names=names, **kwargs) def update_components(self, **kwargs): From b543bcc661b6befdee022908b05476892f35f878 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 08:53:46 +0200 Subject: [PATCH 112/170] docstring blocks -> sub_blocks --- .../stable_diffusion_xl/denoise.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 794bfa297584..7b797e40ee64 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -50,7 +50,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that prepare the latent input for the denoiser. " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @@ -87,7 +87,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object" + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object" ) @property @@ -163,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "Step within the denoising loop that denoise the latents with guidance. " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @@ -269,7 +269,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that denoise the latents with guidance (with controlnet). " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @@ -462,7 +462,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that update the latents. " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @@ -534,7 +534,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that update the latents (for inpainting workflow only). " - "This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" ) @@ -643,7 +643,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): def description(self) -> str: return ( "Pipeline block that iteratively denoise the latents over `timesteps`. " - "The specific steps with each iteration can be customized with `blocks` attributes" + "The specific steps with each iteration can be customized with `sub_blocks` attributes" ) @property @@ -717,7 +717,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `StableDiffusionXLLoopBeforeDenoiser`\n" " - `StableDiffusionXLLoopDenoiser`\n" " - `StableDiffusionXLLoopAfterDenoiser`\n" @@ -739,7 +739,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `StableDiffusionXLLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetLoopDenoiser`\n" " - `StableDiffusionXLLoopAfterDenoiser`\n" @@ -761,7 +761,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents(for inpainting task only). \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" " - `StableDiffusionXLLoopDenoiser`\n" " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" @@ -783,7 +783,7 @@ def description(self) -> str: return ( "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `blocks` sequencially:\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" " - `StableDiffusionXLControlNetLoopDenoiser`\n" " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" From 75540f42eebdda4155d9c8e68978422b49280900 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 08:54:05 +0200 Subject: [PATCH 113/170] more blocks -> sub_blocks --- src/diffusers/modular_pipelines/node_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index fe7cede459d5..f644ddc9edea 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -454,8 +454,8 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): output_params = {} if isinstance(self.blocks, SequentialPipelineBlocks): - last_block_name = list(self.blocks.blocks.keys())[-1] - outputs = self.blocks.blocks[last_block_name].intermediate_outputs + last_block_name = list(self.blocks.sub_blocks.keys())[-1] + outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs else: outputs = self.blocks.intermediate_outputs From 93760b188885344be780465185ea19c6ccd953b5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 09:15:13 +0200 Subject: [PATCH 114/170] InsertableOrderedDict -> InsertableDict --- .../modular_pipelines/modular_pipeline.py | 16 ++++++++-------- .../modular_pipelines/modular_pipeline_utils.py | 6 +++--- .../modular_blocks_presets.py | 14 +++++++------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d1f8b2c3d074..7bb393633934 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -39,7 +39,7 @@ ComponentSpec, ConfigSpec, InputParam, - InsertableOrderedDict, + InsertableDict, OutputParam, format_components, format_configs, @@ -626,7 +626,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): block_trigger_inputs = [] def __init__(self): - sub_blocks = InsertableOrderedDict() + sub_blocks = InsertableDict() for block_name, block_cls in zip(self.block_names, self.block_classes): sub_blocks[block_name] = block_cls() self.sub_blocks = sub_blocks @@ -840,7 +840,7 @@ def __repr__(self): configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" + blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): # Get trigger input for this block trigger = None @@ -942,7 +942,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance = cls() # Create instances if classes are provided - sub_blocks = InsertableOrderedDict() + sub_blocks = InsertableDict() for name, block in blocks_dict.items(): if inspect.isclass(block): sub_blocks[name] = block() @@ -955,7 +955,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo return instance def __init__(self): - sub_blocks = InsertableOrderedDict() + sub_blocks = InsertableDict() for block_name, block_cls in zip(self.block_names, self.block_classes): sub_blocks[block_name] = block_cls() self.sub_blocks = sub_blocks @@ -1203,7 +1203,7 @@ def __repr__(self): configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" + blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): # Get trigger input for this block trigger = None @@ -1435,7 +1435,7 @@ def outputs(self) -> List[str]: return next(reversed(self.sub_blocks.values())).intermediate_outputs def __init__(self): - sub_blocks = InsertableOrderedDict() + sub_blocks = InsertableDict() for block_name, block_cls in zip(self.block_names, self.block_classes): sub_blocks[block_name] = block_cls() self.sub_blocks = sub_blocks @@ -1584,7 +1584,7 @@ def __repr__(self): configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" + blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 86c017fd6d89..37696f5dfac6 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -26,7 +26,7 @@ pass -class InsertableOrderedDict(OrderedDict): +class InsertableDict(OrderedDict): def insert(self, key, value, index): items = list(self.items()) @@ -45,7 +45,7 @@ def insert(self, key, value, index): def __repr__(self): if not self: - return "InsertableOrderedDict()" + return "InsertableDict()" items = [] for i, (key, value) in enumerate(self.items()): @@ -57,7 +57,7 @@ def __repr__(self): obj_repr = f"" items.append(f"{i}: ({repr(key)}, {obj_repr})") - return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])" + return "InsertableDict([\n " + ",\n ".join(items) + "\n])" # YiYi TODO: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py index 0ad865544ee5..d28eb5387a46 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -14,7 +14,7 @@ from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableOrderedDict +from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetUnionInputStep, @@ -308,7 +308,7 @@ def description(self): ) -TEXT2IMAGE_BLOCKS = InsertableOrderedDict( +TEXT2IMAGE_BLOCKS = InsertableDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), ("input", StableDiffusionXLInputStep), @@ -320,7 +320,7 @@ def description(self): ] ) -IMAGE2IMAGE_BLOCKS = InsertableOrderedDict( +IMAGE2IMAGE_BLOCKS = InsertableDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), ("image_encoder", StableDiffusionXLVaeEncoderStep), @@ -333,7 +333,7 @@ def description(self): ] ) -INPAINT_BLOCKS = InsertableOrderedDict( +INPAINT_BLOCKS = InsertableDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), @@ -346,20 +346,20 @@ def description(self): ] ) -CONTROLNET_BLOCKS = InsertableOrderedDict( +CONTROLNET_BLOCKS = InsertableDict( [ ("denoise", StableDiffusionXLAutoControlnetStep), ] ) -IP_ADAPTER_BLOCKS = InsertableOrderedDict( +IP_ADAPTER_BLOCKS = InsertableDict( [ ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), ] ) -AUTO_BLOCKS = InsertableOrderedDict( +AUTO_BLOCKS = InsertableDict( [ ("text_encoder", StableDiffusionXLTextEncoderStep), ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), From 9aaec5b9bc1d7fb7dee8c39afe97570bf9b57983 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 12:46:06 +0200 Subject: [PATCH 115/170] up --- .../modular_pipelines/components_manager.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 2e6c288ad9d9..3394c67cb00a 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -864,6 +864,8 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] Get component IDs by a list of names, optionally filtered by collection. """ ids = set() + if not isinstance(names, list): + names = [names] for name in names: ids.update(self._lookup_ids(name=name, collection=collection)) return list(ids) @@ -895,25 +897,15 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] = def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. - This function is particularly useful for IP-Adapter attention processor patterns, where multiple - attention layers may share the same scale value. It groups dot-separated keys by their values - and finds the shortest common prefix for each group. - - For example, given a dictionary with IP-Adapter attention processor patterns like: - { + For a dictionary with dot-separated keys like: { 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], - 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor': [0.3], } - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: - { - 'down_blocks.1.attentions.1.transformer_blocks': [0.6], - 'up_blocks.1.attentions.0.transformer_blocks': [0.3] + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { + 'down_blocks': [0.6], 'up_blocks': [0.3] } - - This helps identify which attention layers share the same IP-Adapter scale values. """ # First group by values - convert lists to tuples to make them hashable value_to_keys = {} From 58dbe0c29eff89242129d8b501754efd3f9a74b7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 12:46:21 +0200 Subject: [PATCH 116/170] finimsh the quickstart! --- .../en/modular_diffusers/developer_guide.md | 38 +- docs/source/en/modular_diffusers/quicktour.md | 920 ++++++++++++++---- 2 files changed, 723 insertions(+), 235 deletions(-) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md index d4a66b067398..a4d8337840fb 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -39,7 +39,7 @@ Let's see how this works with the Differential Diffusion example. Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them. ```py ->>> IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([ +>>> IMAGE2IMAGE_BLOCKS = InsertableDict([ ... ("text_encoder", StableDiffusionXLTextEncoderStep), ... ("image_encoder", StableDiffusionXLVaeEncoderStep), ... ("input", StableDiffusionXLInputStep), @@ -64,7 +64,7 @@ StableDiffusionXLDenoiseLoop( Description: Denoise step that iteratively denoise the latents. Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method - At each iteration, it runs blocks defined in `blocks` sequencially: + At each iteration, it runs blocks defined in `sub_blocks` sequencially: - `StableDiffusionXLLoopBeforeDenoiser` - `StableDiffusionXLLoopDenoiser` - `StableDiffusionXLLoopAfterDenoiser` @@ -78,13 +78,13 @@ StableDiffusionXLDenoiseLoop( Blocks: [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser) - Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) [1] denoiser (StableDiffusionXLLoopDenoiser) - Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) [2] after_denoiser (StableDiffusionXLLoopAfterDenoiser) - Description: step within the denoising loop that update the latents. This block should be used to compose the `blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) ) ``` @@ -223,7 +223,7 @@ This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up wi ] @property - def intermediates_inputs(self) -> List[InputParam]: + def intermediate_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), @@ -232,7 +232,7 @@ This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up wi ] @property - def intermediates_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam("original_latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("diffdiff_masks", type_hint=torch.Tensor, description="The masks used for the differential diffusion denoising process"), @@ -295,7 +295,7 @@ class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): + ] @property - def intermediates_inputs(self) -> List[str]: + def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -393,7 +393,7 @@ SequentialPipelineBlocks( Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process [6] denoise (SDXLDiffDiffDenoiseLoop) - Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes [7] decode (StableDiffusionXLDecodeStep) Description: Step that decodes the denoised latents into images @@ -447,10 +447,10 @@ It has 4 components: `unet` and `guider` are already used in diff-diff, but it a ) ``` -We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `blocks` attribute is a `InsertableOrderedDict`, so we're able to insert the it at specific position (index `0` here). +We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here). ```py ->>> dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) +>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0) ``` Take a look at the new diff-diff pipeline with ip-adapter! @@ -522,7 +522,7 @@ SequentialPipelineBlocks( Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process [7] denoise (SDXLDiffDiffDenoiseLoop) - Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `blocks` attributes + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes [8] decode (StableDiffusionXLDecodeStep) Description: Step that decodes the denoised latents into images @@ -535,10 +535,10 @@ Let's test it out. We used an orange image to condition the generation via ip-ad ```py >>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() ->>> dd_blocks.blocks.insert("ip_adapter", ip_adapter_block, 0) +>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0) >>> >>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") ->>> dd_pipeline.load_components(torch_dtype=torch.float16) +>>> dd_pipeline.load_default_components(torch_dtype=torch.float16) >>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") >>> dd_pipeline.loader.set_ip_adapter_scale(0.6) >>> dd_pipeline = dd_pipeline.to(device) @@ -627,11 +627,11 @@ StableDiffusionXLControlNetAutoInput( Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape. ```py ->>> dd_blocks.blocks.insert("controlnet_input", control_input_block, 7) ->>> dd_blocks.blocks["denoise"] = controlnet_denoise_block +>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7) +>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block >>> >>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") ->>> dd_pipeline.load_components(torch_dtype=torch.float16) +>>> dd_pipeline.load_default_components(torch_dtype=torch.float16) >>> dd_pipeline = dd_pipeline.to(device) >>> >>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") @@ -709,8 +709,8 @@ With a modular repo, it is very easy for the community to use the workflow you j >>> >>> components = ComponentsManager() >>> ->>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, component_manager=components, collection="diffdiff") ->>> diffdiff_pipeline.loader.load(torch_dtype=torch.float16) +>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff") +>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16) >>> components.enable_auto_cpu_offload() ``` diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/quicktour.md index 94e59b468e27..d9008c5e0c4e 100644 --- a/docs/source/en/modular_diffusers/quicktour.md +++ b/docs/source/en/modular_diffusers/quicktour.md @@ -18,26 +18,32 @@ With Modular Diffusers, we introduce a unified pipeline system that simplifies h **Assemble Like LEGO®**: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that that can be used more conveniently for multiple workflows. -In this guide, we will focus on how to use pipeline like this we built with Modular diffusers 🧨! We will also go over the basics of pipeline blocks, how they work under the hood, and how to assemble SequentialPipelineBlocks and AutoPipelineBlocks in this [guide](TODO). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](developer_guide.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. +In this guide, we will focus on how to build pipelines this way using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](TODO). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](developer_guide.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. Let's get started! The Modular Diffusers Framework consists of three main components: +- ModularPipelineBlocks +- PipelineState & BlockState +- ModularPipeline ## ModularPipelineBlocks Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: -- [`PipelineBlock`](TODO) -- [`SequentialPipelineBlocks`](TODO) -- [`LoopSequentialPipelineBlocks`](TODO) -- [`AutoPipelineBlocks`](TODO) +- [`PipelineBlock`](TODO): The most granular block - you define the computation logic. +- [`SequentialPipelineBlocks`](TODO): A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. +- [`LoopSequentialPipelineBlocks`](TODO): A special type of multi-block that forms loops. +- [`AutoPipelineBlocks`](TODO): A multi-block composed of multiple blocks that are selected at runtime based on the inputs. + +All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be used standalone or combined into larger blocks. Blocks are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. + +It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers -To use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers ```py ->>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep ->>> text_encoder_block = StableDiffusionXLTextEncoderStep() +from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep +text_encoder_block = StableDiffusionXLTextEncoderStep() ``` -Each [`ModularPipelineBlocks`] defines its requirement for components, configs, inputs, intermediate inputs, and outputs. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings such as `prompt_embeds` and `negative_prompt_embeds`. +This is a single `PipelineBlock`. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings outputs such as `prompt_embeds` and `negative_prompt_embeds`. ``` >>> text_encoder_block @@ -59,8 +65,7 @@ StableDiffusionXLTextEncoderStep( ) ``` -More commonly, you can create a `SequentialPipelineBlocks` using a modular blocks preset officially supported in 🧨 Diffusers. - +More commonly, you can create a `SequentialPipelineBlocks` using a block classes preset from 🧨 Diffusers. ```py from diffusers.modular_pipelines import SequentialPipelineBlocks @@ -68,7 +73,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) ``` -This creates a text-to-image pipeline. +This creates a `SequentialPipelineBlocks`, which is a multi-block composed of other blocks. Unlike single blocks (like the `text_encoder_block` we saw earlier), this multi-block has a `sub_blocks` attribute that contains the sub-blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_added_con, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from these blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another. ```py >>> t2i_blocks @@ -92,7 +97,7 @@ SequentialPipelineBlocks( Configs: force_zeros_for_empty_prompt (default: True) - Blocks: + Sub-Blocks: [0] text_encoder (StableDiffusionXLTextEncoderStep) Description: Text Encoder step that generate text_embeddings to guide the image generation @@ -114,14 +119,14 @@ SequentialPipelineBlocks( [4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) Description: Step that prepares the additional conditioning for the text-to-image generation process - [5] denoise (StableDiffusionXLDenoiseLoop) + [5] denoise (StableDiffusionXLDenoiseStep) Description: Denoise step that iteratively denoise the latents. Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method - At each iteration, it runs blocks defined in `blocks` sequencially: + At each iteration, it runs blocks defined in `sub_blocks` sequencially: - `StableDiffusionXLLoopBeforeDenoiser` - `StableDiffusionXLLoopDenoiser` - `StableDiffusionXLLoopAfterDenoiser` - + This block supports both text2img and img2img tasks. [6] decode (StableDiffusionXLDecodeStep) Description: Step that decodes the denoised latents into images @@ -129,11 +134,11 @@ SequentialPipelineBlocks( ) ``` -The blocks preset we used (`TEXT2IMAGE_BLOCKS`) is just a dictionary that maps names to ModularPipelineBlocks classes +The block classes preset (`TEXT2IMAGE_BLOCKS`) we used is just a dictionary that maps names to ModularPipelineBlocks classes ```py >>> TEXT2IMAGE_BLOCKS -InsertableOrderedDict([ +InsertableDict([ 0: ('text_encoder', ), 1: ('input', ), 2: ('set_timesteps', ), @@ -144,51 +149,51 @@ InsertableOrderedDict([ ]) ``` -When we create a `SequentialPipelineBlocks` from this preset, it instantiates each class into actual block objects. Its `blocks` attribute contains these instantiated objects: +When we create a `SequentialPipelineBlocks` from this preset, it instantiates each block class into actual block objects. Its `sub_blocks` attribute now contains these instantiated objects: ```py ->>> t2i_blocks.blocks -InsertableOrderedDict([ +>>> t2i_blocks.sub_blocks +InsertableDict([ 0: ('text_encoder', ), 1: ('input', ), 2: ('set_timesteps', ), 3: ('prepare_latents', ), 4: ('prepare_add_cond', ), - 5: ('denoise', ), + 5: ('denoise', ), 6: ('decode', ) ]) ``` -Note that both the preset and the `blocks` attribute are `InsertableOrderedDict` objects, which allows you to modify them in several ways: +Note that both the block classes preset and the `sub_blocks` attribute are `InsertableDict` objects. This is a custom dictionary that extends `OrderedDict` with the ability to insert items at specific positions. You can perform all standard dictionary operations (get, set, delete) plus insert items at any index, which is particularly useful for reordering or inserting blocks in the middle of a pipeline. -**Add a block/block_class at specific positions:** +**Add a block:** ```py -# Add to preset (class) +# Add a block class to the preset BLOCKS.insert("block_name", BlockClass, index) -# Add to blocks attribute (instance) -t2i_blocks.blocks.insert("block_name", block_instance, index) +# Add a block instance to the `sub_blocks` attribute +t2i_blocks.sub_blocks.insert("block_name", block_instance, index) ``` -**Remove blocks:** +**Remove a block:** ```py # remove a block class from preset BLOCKS.pop("text_encoder") # split out a block instance on its own -text_encoder_block = t2i_blocks.blocks.pop("text_encoder") +text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder") ``` -**Swap/replace blocks:** +**Swap block:** ```py -# Replace in preset (class) +# Replace block class in preset BLOCKS["prepare_latents"] = CustomPrepareLatents -# Replace in blocks attribute (instance) -t2i_blocks.blocks["prepare_latents"] = CustomPrepareLatents() +# Replace in sub_blocks attribute +t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents() ``` This means you can mix-and-match blocks in very flexible ways. Let's see some real examples: -**Example 1: Adding IP-Adapter to the preset** -Let's insert IP-Adapter at index 0 (before the text_encoder block) to create a text-to-image pipeline with IP-Adapter support: +**Example 1: Adding IP-Adapter to the Block Classes Preset** +Let's make a new block classes preset by insert IP-Adapter at index 0 (before the text_encoder block), and create a text-to-image pipeline with IP-Adapter support: ```py from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep @@ -197,31 +202,16 @@ CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS) ``` -**Example 2: Extracting a block from the pipeline** -You can extract a block instance from the pipeline to use it independently. A common pattern is to extract the text_encoder to process prompts once, then reuse the text embeddings to generate multiple images with different settings (schedulers, seeds, inference steps). +**Example 2: Extracting a block from a multi-block** +You can extract a block instance from the multi-block to use it independently. A common pattern is to use text_encoder to process prompts once, then reuse the text embeddings outputs to generate multiple images with different settings (schedulers, seeds, inference steps). We can do this by simply extracting the text_encoder block from the pipeline. ```py ->>> text_encoder_blocks = t2i_blocks.blocks.pop("text_encoder") +# this gives you StableDiffusionXLTextEncoderStep() +>>> text_encoder_blocks = t2i_blocks.sub_blocks.pop("text_encoder") >>> text_encoder_blocks -StableDiffusionXLTextEncoderStep( - Class: PipelineBlock - Description: Text Encoder step that generate text_embeddings to guide the image generation - Components: - text_encoder (`CLIPTextModel`) - text_encoder_2 (`CLIPTextModelWithProjection`) - tokenizer (`CLIPTokenizer`) - tokenizer_2 (`CLIPTokenizer`) - guider (`ClassifierFreeGuidance`) - Configs: - force_zeros_for_empty_prompt (default: True) - Inputs: - prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None - Intermediates: - - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds -) ``` -the pipeline now has fewer components and no longer has the `text_encoder` block: +the multi-block now has fewer components and no longer has the `text_encoder` block. If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead. ```py >>> t2i_blocks @@ -271,6 +261,33 @@ SequentialPipelineBlocks( ) ``` + + +💡 You can find all the block classes presets we support for each model in `ALL_BLOCKS`. + +```py +# For Stable Diffusion XL +from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +ALL_BLOCKS +# For other models... +from diffusers.modular_pipelines. import ALL_BLOCKS +``` + +Each model provides a dictionary that maps all supported tasks/techniques to their corresponding block classes presets. For SDXL, it is + +```py +ALL_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS, +} +``` + + + We will not go over how to write your own ModularPipelineBlocks but you can learn more about it [here](TODO). This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. @@ -306,8 +323,8 @@ In standard `model_index.json`, each component entry is a `(library, class)` tup In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` -- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (can be `None` if not loaded) -- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint` +- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded) +- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. ```py "text_encoder": [ @@ -325,7 +342,20 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra } ], ``` +Some components may not have `repo` field, they cannot be loaded from a repository and can only be created with default config from the pipeline +```py + "image_processor": [ + "diffusers", + "VaeImageProcessor", + { + "type_hint": [ + "diffusers", + "VaeImageProcessor" + ] + } + ], +``` Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. @@ -387,19 +417,33 @@ Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whethe ```py # This will load ALL the expected components into pipeline -t2i_pipeline.load_components(torch_dtype=torch.float16) -t2i_pipeline.to(device) +import torch +t2i_pipeline.load_default_components(torch_dtype=torch.float16) +t2i_pipeline.to("cuda") ``` -All expected components are now loaded into the pipeline. You can also partially load specific components using the `component_names` argument. For example, to only load unet and vae: +All expected components are now loaded into the pipeline. You can also partially load specific components using the `names` argument. For example, to only load unet and vae: ```py ->>> t2i_pipeline.load_components(component_names=["unet", "vae"]) +>>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16) ``` -You can inspect the pipeline's loading status through its `loader` attribute to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. It has the same structure as the `modular_model_index.json` we discussed earlier - each component entry contains the `(library, class, loading_specs)` format. You'll need to understand that structure to properly read the loading status below. +You can inspect the pipeline's loading status through its `loader` attribute to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. + +For example, if your pipeline needs a `text_encoder` component, the loader will include the loading spec for `text_encoder` from the modular repo. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't appear in the loader even if it exists in the modular repo. + +The loader has the same structure as `modular_model_index.json` - each component entry contains the `(library, class, loading_specs)` format. You'll need to understand that structure to properly read the loading status below. + + + +💡 **How to read the loader**: +- **`library` and `class` fields**: Show info about actually loaded components. If `null`, the component is not loaded yet. +- **`loading_specs`**: If it does not have `repo` field or if it is `null`, the component cannot be loaded from a repository and can only be created with default config by the pipeline. + + + +Let's inspect the `t2i_pipeline.loader`, you can see all the components expected to load are listed as entries in the loader. The `guider` and `image_processor` components were created using default config (their `library` and `class` field are populated, this means they are initialized, but their loading spec dict is missing loading related fields). The `vae` and `unet` components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their `library`, `class` fields are `null`), but you can examine their loading specs to see where they would be loaded from when you call `load_components()`. -Let's inspect the `t2i_pipeline`, you can see all the components expected to load are listed as entries in the loader. The `guider` and `image_processor` components were created using default config (their `library` and `class` field are populated, this means they are initialized, but `loading_spec["repo"]` is null). The `vae` and `unet` components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their `library`, `class` fields are `null`), but you can examine their loading specs to see where they would be loaded from when you call `load_components()`. ```py >>> t2i_pipeline.loader @@ -411,28 +455,20 @@ StableDiffusionXLModularLoader { "diffusers", "ClassifierFreeGuidance", { - "repo": null, - "revision": null, - "subfolder": null, "type_hint": [ "diffusers", "ClassifierFreeGuidance" - ], - "variant": null + ] } ], "image_processor": [ "diffusers", "VaeImageProcessor", { - "repo": null, - "revision": null, - "subfolder": null, "type_hint": [ "diffusers", "VaeImageProcessor" - ], - "variant": null + ] } ], "scheduler": [ @@ -535,6 +571,58 @@ StableDiffusionXLModularLoader { ] } ``` + +There are also a few properties that can provide a quick summary of component loading status: + +```py +# All components expected by the pipeline +>>> t2i_pipeline.loader.component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor'] + +# Components that are not loaded yet (will be loaded with from_pretrained) +>>> t2i_pipeline.loader.null_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler'] + +# Components that will be loaded from pretrained models +>>> t2i_pipeline.loader.pretrained_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae'] + +# Components that are created with default config (no repo needed) +>>> t2i_pipeline.loader.config_component_names +['guider', 'image_processor'] +``` + +### Modifying Loading Specs + +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. The pipeline's `loader` attribute is synced with these specs - it shows you exactly what will be loaded and from where. + +You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. + +```py +# Original spec in modular_model_index.json +"unet": [ + null, null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "subfolder": "unet", + "variant": "fp16" + } +] + +# Modified spec - changed repo, subfolder, and variant +"unet": [ + null, null, + { + "repo": "RunDiffusion/Juggernaut-XL-v9", + "subfolder": "unet", + "variant": "fp16" + } +] +``` + +When you call `pipeline.load_components(...)`/`pipeline.load_default_components()`, it will now load from the new repository by default. + + ### Updating components in a `ModularPipeline` Similar to `DiffusionPipeline`, You could load an components separately to replace the default one in the pipeline. But in Modular Diffusers system, you need to use `ComponentSpec` to load/create them. @@ -595,8 +683,7 @@ StableDiffusionXLModularLoader { } ``` - -### Run a `ModularPipeline` +### Running a `ModularPipeline` The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: @@ -615,175 +702,458 @@ Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pi You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list. +**Important**: It is important to always check the docstring because arguments can be different from standard pipelines that you're familar with. For example, in Modular Diffusers we standardized controlnet image input as `control_image`, but regular pipelines have inconsistencies over the names, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`. + +**Note**: The `output` list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`. + ```py t2i_pipeline.doc ``` +#### Text-to-Image, Image-to-Image, and Inpainting + +These are minimum inference example for our basic tasks: text-to-image, image-to-image and inpainting. The process to create different pipelines is the same - only difference is the block classes presets. The inference is also more or less same to standard pipelines, but please always check `.doc` for correct input names and remember to pass `output="images"`. + + + + + ```py import torch from diffusers.modular_pipelines import SequentialPipelineBlocks from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS -t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) +# create pipeline from official blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) modular_repo_id = "YiYiXu/modular-loader-t2i" -t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) +pipeline = blocks.init_pipeline(modular_repo_id) -t2i_pipeline.load_components(torch_dtype=torch.float16) -t2i_pipeline.to("cuda") +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") -image = t2i_pipeline(prompt="a cat", output="images")[0] +# run pipeline, need to pass a "output=images" argument +image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0] image.save("modular_t2i_out.png") ``` + + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" +init_image = load_image(url) +prompt = "a dog catching a frisbee in the jungle" +image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0] +image.save("modular_i2i_out.png") +``` + + + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers.utils import load_image + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" +mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png" + +init_image = load_image(img_url) +mask_image = load_image(mask_url) + +prompt = "A deep sea diver floating" +image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0] +image.save("moduar_inpaint_out.png") +``` + + + + +#### ControlNet + +For ControlNet, we provide one auto block you can place at the `denoise` step. Let's create it and inspect it to see what it tells us. + + + +💡 **How to explore new tasks**: When you want to figure out how to do a specific task in Modular Diffusers, it is a good idea to start by checking what block classes presets we offer in `ALL_BLOCKS`. Then create the block instance and inspect it - it will show you the required components, description, and sub-blocks. This is crucial for understanding what each block does and what it needs. + + + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +>>> ALL_BLOCKS["controlnet"] +InsertableDict([ + 0: ('denoise', ) +]) +>>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]() +>>> controlnet_blocks +StableDiffusionXLAutoControlnetStep( + Class: SequentialPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'mask', 'control_mode', 'control_image', 'controlnet_cond'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('mask')`). + ==================================================================================================== + + + Description: Controlnet auto step that prepare the controlnet input and denoise the latents. It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks. (it should be replace at 'denoise' step) + + + Components: + controlnet (`ControlNetUnionModel`) + control_image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + + Sub-Blocks: + [0] controlnet_input (StableDiffusionXLAutoControlNetInputStep) + Description: Controlnet Input step that prepare the controlnet input. + This is an auto pipeline block that works for both controlnet and controlnet_union. + (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. + - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped. + + [1] controlnet_denoise (StableDiffusionXLAutoControlNetDenoiseStep) + Description: Denoise step that iteratively denoise the latents with controlnet. This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks.This block should not be used without a controlnet_cond input - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided. - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided. - If neither mask nor controlnet_cond are provided, step will be skipped. + +) +``` + + + +💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](TODO) for more details. + + + +The block shows us it has two steps (prepare inputs + denoise) and supports all tasks with both controlnet and controlnet union. Most importantly, it tells us to place it at the 'denoise' step. Let's do exactly that: + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoControlnetStep +from diffusers.utils import load_image + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +# these two lines applies controlnet +controlnet_blocks = StableDiffusionXLAutoControlnetStep() +blocks.sub_blocks["denoise"] = controlnet_blocks +``` + +Before we convert the blocks into a pipeline and load its components, let's inspect the blocks and its docs again to make sure it was assembled correctly. You should be able to see that `controlnet` and `control_image_processor` are now listed as `Components`, so we should initialize the pipeline with a repo that contains desired loading specs for these 2 components. + +```py +# make sure to a modular_repo including controlnet +modular_repo_id = "YiYiXu/modular-demo-auto" +pipeline = blocks.init_pipeline(modular_repo_id) +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +# generate +canny_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" +) +image = pipeline( + prompt="a bird", controlnet_conditioning_scale=0.5, control_image=canny_image, output="images" +)[0] +image.save("modular_control_out.png") +``` + +#### IP-Adapter + +**Challenge time!** Before we show you how to apply IP-adapter, try doing it yourself! Use the same process we just walked you through with ControlNet: check the official blocks preset, inspect the block instance and docstring `.doc`, and adapt a regular IP-adapter example to modular. + +Let's walk through the steps: + +1. Check blocks preset + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +>>> ALL_BLOCKS["ip_adapter"] +InsertableDict([ + 0: ('ip_adapter', ) +]) +``` + +2. inspect the block & doc + +``` +>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep +>>> ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep() +>>> ip_adapter_blocks +StableDiffusionXLAutoIPAdapterStep( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'ip_adapter_image'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). + ==================================================================================================== + + + Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step. + + + + Components: + image_encoder (`CLIPVisionModelWithProjection`) + feature_extractor (`CLIPImageProcessor`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + + Sub-Blocks: + • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) + Description: IP Adapter step that prepares ip adapter image embeddings. + Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.loader. + e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale(). + See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details + +) +``` +3. follow the instruction to build + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS + +# create pipeline from official blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) -## An slightly advanced Workflow +# insert ip_adapter_blocks before the input step as instructed +blocks.sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1) -We've learned the basic components of the Modular Diffusers System. Now let's tie everything together with more practical example that demonstrates the true power of Modular Diffusers: working between with multiple pipelines that can share components. +# inspec the blocks before you convert it into pipelines, +# and make sure to use a repo that contains the loading spec for all components +# for ip-adapter, you need image_encoder & feature_extractor +modular_repo_id = "YiYiXu/modular-demo-auto" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.loader.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter_sdxl.bin" +) +pipeline.loader.set_ip_adapter_scale(0.8) +pipeline.to("cuda") +``` + +4. adapt an example to modular + +We are using [this one](https://huggingface.co/docs/diffusers/using-diffusers/ip_adapter?ipadapter-variants=IP-Adapter+Plus#ip-adapter) from our IP-Adapter doc! + + +```py +from diffusers.utils import load_image +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png") +image = pipeline( + prompt="a polar bear sitting in a chair drinking a milkshake", + ip_adapter_image=image, + negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + output="images" +)[0] +image.save("modular_ipa_out.png") +``` + + +## A more practical example + +We've learned the basic components of the Modular Diffusers System. Now let's tie everything together with more practical example that demonstrates the true power of Modular Diffusers: working between with multiple pipelines that can share components. + +In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will use IP-adapter, LoRA, and ControlNet. + +Let's setup the text-to-image workflow. Instead of putting all blocks into one complete pipeline, we'll create separate `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. ```py import torch -from diffusers.modular_pipelines import SequentialPipelineBlocks, ComponentsManager -from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS # create t2i blocks and then pop out the text_encoder step and decoder step so that we can use them in standalone manner -t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS.copy()) -text_blocks = t2i_blocks.blocks.pop("text_encoder") -decoder_blocks = t2i_blocks.blocks.pop("decode") +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"]) +text_blocks = t2i_blocks.sub_blocks.pop("text_encoder") +decoder_blocks = t2i_blocks.sub_blocks.pop("decode") +``` -# Create a refiner blocks -# - removing image_encoder a since we'll use latents from t2i -# - removing decode since we already created a seperate decoder_block -i2i_blocks_dict = IMAGE2IMAGE_BLOCKS.copy() -i2i_blocks_dict.pop("image_encoder") -i2i_blocks_dict.pop("decode") -refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(i2i_blocks_dict) +Next, convert them into runnable pipelines. We'll use a Components Manager with auto offloading strategy. +**Components Manager**: Create one manager and pass it to `init_pipeline` along with a collection name. All models loaded by that pipeline will be added to the manager under that collection. + +**Auto Offloading**: All components are placed on CPU and only moved to device right before their forward pass. The manager monitors device memory and may move components off-device to make space for new ones. Unlike `DiffusionPipeline.enable_model_cpu_offload()`, this works across all components in the manager and all your workflows. + + +```py +from diffusers import ComponentsManager # Set up component manager and turn on the offloading components = ComponentsManager() components.enable_auto_cpu_offload(device="cuda") +``` -# convert all blocks into runnable pipelines: text_node, decoder_node, t2i_pipe, refiner_pipe -t2i_repo = "YiYiXu/modular-loader-t2i" -refiner_repo = "YiYiXu/modular_refiner" -dtype = torch.float16 +Since we have a modular setup where different pipelines may share components, we recommend using a standalone loader to load components all at once and add them to each pipeline with `update_components()`. + + + -text_node = text_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") -text_node.load_components(torch_dtype=dtype) +💡 **Load components without pipeline blocks**: +- `blocks.init_pipeline(repo)` creates a pipeline with a built-in loader that only includes components its blocks needs +- `StableDiffusionXLModularLoader.from_pretrained(repo)` set up a standalone loader that includes everything in the repo's `modular_model_index.json` -decoder_node = decoder_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") -decoder_node.load_components(torch_dtype=dtype) +See the [Loader Guide](TODO) for more details. -t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, component_manager=components, collection="t2i") -t2i_pipe.load_components(torch_dtype=dtype) + -# for refiner pipeline, only unet is unique so we only load unet here, and we will reuse other components -refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, component_manager=components, collection="refiner") -refiner_pipe.load_components(component_names="unet", torch_dtype=dtype) +```py +from diffusers import StableDiffusionXLModularLoader +t2i_repo = "YiYiXu/modular-demo-auto" +t2i_loader = StableDiffusionXLModularLoader.from_pretrained(t2i_repo, components_manager=components, collection="t2i") + +text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) +decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) +t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) ``` -let's inspect components manager here, you can see that 5 models are automatically registered: two text encoders, two UNets, and one VAE. The models are organized by collection - 4 models under "t2i" and one UNet under "refiner". This happens because we passed a `collection` parameter when initializing each pipeline. For example, when we created the refiner pipeline, we did `refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, component_manager=components, collection="refiner")`. All models loaded by `refiner_pipe.load_components(...)` are automatically placed under the "refiner" collection. +We'll load components in `t2i_loader`. You can get the list of all loadable components from loader's `pretrained_component_names` property. -Notice that all models are currently on CPU with execution device "cuda:0" - this is due to the auto CPU offloading strategy we enabled with `components.enable_auto_cpu_offload(device="cuda")`. +```py +>>> t2i_loader.pretrained_component_names +['controlnet', 'image_encoder', 'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'] +``` -The manager also displays useful info like dtype and memory size for each model. +It include controlnet and image_encoder for ip-adapter that we don't need now. But I'll load them anyway since they'll stay on CPU and I might use them later. But you can choose what to load in the `names` argument. + +```py +import torch +# inspect before you load +# t2i_loader +t2i_loader.load(t2i_loader.pretrained_component_names, torch_dtype=torch.float16) +``` +All the models are registered to components manager under the collection "t2i". ```py >>> components Components: -====================================================================================================================================================================================== +============================================================================================================================================================ Models: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16 | 1.29 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder_2|null|null | t2i -text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16 | 0.23 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | t2i -unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.78 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | t2i -unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.21 | stabilityai/stable-diffusion-xl-refiner-1.0|unet|null|null | refiner -vae | AutoencoderKL | cpu(cuda:0) | torch.float16 | 0.16 | madebyollin/sdxl-vae-fp16-fix|null|null|null | t2i --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------------------------------------------------------------------------------ +Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection +------------------------------------------------------------------------------------------------------------------------------------------------------------ +vae | AutoencoderKL | cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i +image_encoder | CLIPVisionModelWithProjection| cpu(cuda:0) | torch.float16| 3.44 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | t2i +text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16| 0.23 | SG161222/RealVisXL_V4.0|text_encoder|null|null | t2i +unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16| 4.78 | SG161222/RealVisXL_V4.0|unet|null|null | t2i +text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16| 1.29 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | t2i +controlnet | ControlNetModel | cpu(cuda:0) | torch.float16| 2.33 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | t2i +------------------------------------------------------------------------------------------------------------------------------------------------------------ Other Components: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Collection --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -tokenizer | CLIPTokenizer | t2i -tokenizer_2 | CLIPTokenizer | t2i -scheduler | EulerDiscreteScheduler | t2i --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +------------------------------------------------------------------------------------------------------------------------------------------------------------ +Name | Class | Collection +------------------------------------------------------------------------------------------------------------------------------------------------------------ +tokenizer_2 | CLIPTokenizer | t2i +tokenizer | CLIPTokenizer | t2i +scheduler | EulerDiscreteScheduler | t2i +------------------------------------------------------------------------------------------------------------------------------------------------------------ Additional Component Info: ================================================== ``` +Let's add the loaded components to each pipeline. We'll follow this pattern for each pipeline: +1. Check what components the pipeline needs: inspect `pipeline.loader` or use `loader.null_component_names` +2. Get them from the components manager: use its `search_models()`/`get_one`/`get_components_from_names` method +3. Update the pipeline: `pipeline.update_components()` +4. Verify the components are loaded correctly: inspect `pipeline.loader` as well as components manager -Now let's reuse components from the t2i pipeline in the refiner. First, let's check the loading status of the refiner pipeline to understand what components are needed: +We will start with `decoder_node`. First, check what components it needs: ```py ->>> refiner_pipe.loader +>>> decoder_node.loader.null_component_names +['vae'] ``` +The pipeline only needs a `vae`. Looking at the components manager table, there's only one VAE available: -Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is already correctly loaded from the refiner repository. We need to load the remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler`. Since these components are already available in the t2i collection, we can reuse them instead of loading duplicates. - -Now let's reuse the components from the t2i pipeline in the refiner. We use the`|` to select multiple components from components manager at once: +``` +Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection +---------------------------------------------------------------------------------------------------------------------- +vae | AutoencoderKL| cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i +``` +Since there's only one VAE, we can get it using its unique Load ID: ```py -# Reuse components from t2i pipeline (select everything at once) -reuse_components = components.get("text_encoder_2|scheduler|vae|tokenizer_2", as_name_component_tuples=True) -refiner_pipe.update_components(**dict(reuse_components)) +vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null") +decoder_node.update_components(vae=vae) ``` -You'll see warnings indicating that these components already exist in the components manager: +Verify it's correctly loaded: -```out -component 'text_encoder_2' already exists as 'text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0' -component 'tokenizer_2' already exists as 'tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d' -component 'scheduler' already exists as 'scheduler_e3435f63-266a-4427-9383-eb812e830fe8' -component 'vae' already exists as 'vae_357eee6a-4a06-46f1-be83-494f7d60ca69' +```py +decoder_node.loader ``` +Now let's do the same for `text_node`. Get the list of components the pipeline needs to load: -These warnings are expected and indicate that the components manager is correctly identifying that these components are already loaded. The system will reuse the existing components rather than creating duplicates. +```py +>>> text_node.loader.null_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2'] +``` +Pass the list directly to the components manager to get the components and add it to the pipeline + +```py +text_components = components.get_components_by_names(text_node.loader.null_component_names) +# Add components to pipeline +text_node.update_components(**text_components) + +# Verify components are loaded +assert not text_node.loader.null_component_names +text_node.loader +``` -Let's check the components manager again to see the updated state: +Finally, let's set up `t2i_pipe`: ```py ->>> components -Components: -====================================================================================================================================================================================== -Models: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16 | 0.23 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | t2i -text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16 | 1.29 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder_2|null|null | t2i - | | | | | | refiner -vae | AutoencoderKL | cpu(cuda:0) | torch.float16 | 0.16 | madebyollin/sdxl-vae-fp16-fix|null|null|null | t2i - | | | | | | refiner -unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.78 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | t2i -unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16 | 4.21 | stabilityai/stable-diffusion-xl-refiner-1.0|unet|null|null | refiner --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -Other Components: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Collection --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -tokenizer_2 | CLIPTokenizer | t2i - | | refiner -tokenizer | CLIPTokenizer | t2i -scheduler | EulerDiscreteScheduler | t2i - | | refiner --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +# Get unet & scheduler from components manager and add to pipeline +comps = components.get_components_by_names(t2i_pipe.loader.null_component_names) +t2i_pipe.update_components(**comps) -Additional Component Info: -================================================== -``` +# Verify everything is loaded +assert not t2i_pipe.loader.null_component_names +t2i_pipe.loader -Notice how `text_encoder_2`, `vae`, `tokenizer_2`, and `scheduler` now appear under both "t2i" and "refiner" collections. +# Verify components manager hasn't changed (we only reused existing components) +components +``` -We can start to generate an image with the t2i pipeline and refine it. +We can start to generate an image with the t2i pipeline. First to run the prompt through text_node to get prompt embeddings @@ -794,79 +1164,197 @@ First to run the prompt through text_node to get prompt embeddings ```py -prompt = "A crystal orb resting on a wooden table with a yellow rubber duck, surrounded by aged scrolls and alchemy tools, illuminated by candlelight, detailed texture, high resolution image" - +prompt = "an astronaut" text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"]) ``` -Now generate latents with t2i pipeline and then refine with refiner. Note that both our `t2i_pipe` and `refiner_pipe` do not have decoder steps since we separated them out earlier, so we need to use `output="latents"` instead of `output="images"`. +Now generate latents with t2i pipeline and then decode with decoder. + + +```py +generator = torch.Generator(device="cuda").manual_seed(0) +latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_t2i, output="images")[0] +image.save("modular_part2_t2i.png") + +``` + +Now let's add a LoRA to our pipeline. With the modular approach we will be able to reuse intermediate outputs from blocks that otherwise needs to be re-run. Let's load the LoRA weights and see what happens: + +```py +t2i_loader.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") +components +``` +Notice that the "Additional Component Info" section shows that only the `unet` component has the LoRA adapter loaded. This means we can skip the text encoding step and reuse the existing embeddings, making the generation much faster. + +```out +Components: +============================================================================================================================================================ +... +Additional Component Info: +================================================== + +unet: + Adapters: ['toy_face'] +``` + -💡 `t2i_pipe.blocks` shows you what steps this pipeline takes. You can see that our `t2i_pipe` no longer includes the `text_encoder` and `decode` steps since we removed them earlier when we popped them out to create separate nodes. +🔍 Alternatively, you can find a component's ID and then use `get_model_info` to get detailed metadata about that component: ```py ->>> t2i_pipe.blocks -SequentialPipelineBlocks( - Class: ModularPipelineBlocks +id = components.get_ids("unet")[0] +components.get_model_info(id) +# {'model_id': 'unet_6c2b839d-ec39-4ce9-8741-333ba6d25932', 'added_time': 1751101289.203884, 'collection': 't2i', 'class_name': 'UNet2DConditionModel', 'size_gb': 4.940812595188618, 'adapters': ['toy_face'], 'has_hook': True, 'execution_device': device(type='cuda', index=0)} +``` + - Description: +```py +generator = torch.Generator(device="cuda").manual_seed(0) +latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_lora, output="images")[0] +image.save("modular_part2_lora.png") +``` - Components: - scheduler (`EulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - unet (`UNet2DConditionModel`) +IP-adapter can also be used as a standalone pipeline. We can generate the embeddings once and reuse them for different workflows. - Blocks: - [0] input (StableDiffusionXLInputStep) - Description: Input processing step that: - 1. Determines `batch_size` and `dtype` based on `prompt_embeds` - 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` - - All input tensors are expected to have either batch_size=1 or match the batch_size - of prompt_embeds. The tensors will be duplicated across the batch dimension to - have a final batch_size of batch_size * num_images_per_prompt. +```py +from diffusers.utils import load_image - [1] set_timesteps (StableDiffusionXLSetTimestepsStep) - Description: Step that sets the scheduler's timesteps for inference +ipa_blocks = ALL_BLOCKS["ip_adapter"]["ip_adapter"]() +ipa_node = ipa_blocks.init_pipeline(t2i_repo, components_manager=components) +comps = components.get_components_by_names(ipa_node.loader.null_component_names) +ipa_node.update_components(**comps) - [2] prepare_latents (StableDiffusionXLPrepareLatentsStep) - Description: Prepare latents step that prepares the latents for the text-to-image generation process +t2i_loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") +t2i_loader.set_ip_adapter_scale(0.6) - [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) - Description: Step that prepares the additional conditioning for the text-to-image generation process +# check it's correctly loaded +assert not ipa_node.loader.null_component_names +ipa_node.loader +# find out inputs/outputs +print(ipa_node.doc) - [4] denoise (StableDiffusionXLDenoiseLoop) - Description: Denoise step that iteratively denoise the latents. - Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method - At each iteration, it runs blocks defined in `blocks` sequencially: - - `StableDiffusionXLLoopBeforeDenoiser` - - `StableDiffusionXLLoopDenoiser` - - `StableDiffusionXLLoopAfterDenoiser` - +ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png") +ipa_embeddings = ipa_node(ip_adapter_image=ip_adapter_image, output=["ip_adapter_embeds","negative_ip_adapter_embeds"]) -) +generator = torch.Generator(device="cuda").manual_seed(0) +latents_ipa = t2i_pipe(**text_embeddings, **ipa_embeddings, num_inference_steps=25, generator=generator, output="latents") + +image = decoder_node(latents=latents_ipa, output="images")[0] +image.save("modular_part2_lora_ipa.png") +``` + +We can create a new ControlNet workflow by modifying the pipeline blocks, reusing components as much as possible, and see how it affects the generation. + +We want to use a different ControlNet from the one that's already loaded. + +```py +from diffusers import ComponentSpec, ControlNetModel +control_blocks = ALL_BLOCKS["controlnet"]["denoise"]() +# update the t2i_blocks and create pipeline +t2i_blocks.sub_blocks["denoise"] = control_blocks +t2i_control_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) + +# fetch the controlnet_pose seperately since we need to change name when adding it to the pipeline +controlnet_spec = ComponentSpec(name="controlnet_pose", type_hint=ControlNetModel, repo="thibaud/controlnet-openpose-sdxl-1.0") +controlnet = controlnet_spec.load(torch_dtype=torch.float16) +t2i_control_pipe.update_components(controlnet=controlnet) + +# fetch the rest of the components from the components manager +comps = components.get_components_by_names(t2i_control_pipe.loader.null_component_names) +t2i_control_pipe.update_components(**comps) + +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png") +generator = torch.Generator(device="cuda").manual_seed(0) +latents_control = t2i_control_pipe(**text_embeddings, **ipa_embeddings, control_image=control_image, num_inference_steps=25, generator=generator, output="latents") + +image = decoder_node(latents=latents_control, output="images")[0] +image.save("modular_part2_lora_ipa_control.png") ``` - + +Now set up refiner workflow. For refiner blocks, we removed `image_encoder` since the refiner works with latents directly, and `decoder` since we already have a dedicated one. We keep `text_encoder` because SDXL refiner encodes text prompts differently from the text-to-image pipeline, so we cannot share it. ```py -latents = t2i_pipe(**text_embeddings, num_inference_steps=25, output="latents") -refined_latents = refiner_pipe(image_latents=latents, prompt=prompt, num_inference_steps=10, output="latents") +# Create a refiner blocks +# - removing image_encoder a since we'll use latents from t2i +# - removing decode since we already created a seperate decoder_block +refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"]) +refiner_blocks.sub_blocks.pop("image_encoder") +refiner_blocks.sub_blocks.pop("decode") ``` -To get the final images, we need to pass the latents through our separate decoder node: +Create refiner pipeline. refiner has a different unet and use only one text_encoder so it is hosted in a different repo. We pass the same components manager to refiner pipeline, along with a unique "refiner" collection. ```py -image = decoder_node(latents=latents, output="images")[0] -refined_image = decoder_node(latents=refined_latents, output="images")[0] +refiner_repo = "YiYiXu/modular_refiner" +refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner") +``` + + +We want to reuse components from the t2i pipeline in the refiner as much as possible. First, let's check the loading status of the refiner pipeline to understand what components are needed: + +```py +>>> refiner_pipe.loader +``` + +Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is different from the one we loaded for text-to-image. The remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler` are already available in the t2i collection, we can reuse them instead of loading duplicates. + +```py +refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) + +# verify loaded correctly +refiner_pipe.loader + +# veryfiy registered to components manager under refiner +components +``` + +Now let's reuse the components from the t2i pipeline in the refiner. We use the`|` to select multiple components from components manager at once: + +```py +# Reuse components from t2i pipeline (select everything at once) +reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2") +refiner_pipe.update_components(**reuse_components) +``` + +You'll see warnings indicating that these components already exist in the components manager: + +```out +component 'text_encoder_2' already exists as 'text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0' +component 'tokenizer_2' already exists as 'tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d' +component 'scheduler' already exists as 'scheduler_e3435f63-266a-4427-9383-eb812e830fe8' +component 'vae' already exists as 'vae_357eee6a-4a06-46f1-be83-494f7d60ca69' ``` -## YiYi TODO: maybe more on controlnet/lora/ip-adapter +These warnings are expected and indicate that the components manager is correctly identifying that these components are already loaded. The system will reuse the existing components rather than creating duplicates. +Let's check the components manager again to see the updated state. You should see `text_encoder_2`, `vae`, `tokenizer_2`, and `scheduler` now appear under both "t2i" and "refiner" collections. +Now let's refine! +```py +# refine the latents from base text-to-image workflow +refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_t2i_refine_out.png") +# refine the latents from the text-to-image lora workflow +refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_lora_refine_out.png") + +# refine the latents from the text-to-image + lora + ip-adapter workflow +refined_latents = refiner_pipe(image_latents=latents_ipa, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_ipa_refine_out.png") +# refine the latents from the text-to-image + lora + ip-adapter + controlnet workflow +refined_latents = refiner_pipe(image_latents=latents_control, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_control_refine_out.png") +``` From 49ea4d1bf557f4b7afd3836ce48b540991d44d9e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 12:50:11 +0200 Subject: [PATCH 117/170] style --- .../modular_pipelines/components_manager.py | 89 ++++++++++++------- .../modular_pipelines/modular_pipeline.py | 29 ++++-- .../stable_diffusion_xl/__init__.py | 8 +- .../modular_blocks_presets.py | 4 +- .../stable_diffusion_xl/modular_loader.py | 4 +- 5 files changed, 81 insertions(+), 53 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 3394c67cb00a..f2a2b0b0801b 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -232,8 +232,18 @@ def search_best_candidate(module_sizes, min_memory_offload): class ComponentsManager: - _available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"] - + _available_info_fields = [ + "model_id", + "added_time", + "collection", + "class_name", + "size_gb", + "adapters", + "has_hook", + "execution_device", + "ip_adapter", + ] + def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added @@ -241,10 +251,16 @@ def __init__(self): self.model_hooks = None self._auto_offload_enabled = False - def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None): + def _lookup_ids( + self, + name: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + components: Optional[OrderedDict] = None, + ): """ - Lookup component_ids by name, collection, or load_id. Does not support pattern matching. - Returns a set of component_ids + Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of + component_ids """ if components is None: components = self.components @@ -318,10 +334,14 @@ def add(self, name, component, collection: Optional[str] = None): if component_id not in self.collections[collection]: comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) for comp_id in comp_ids_in_collection: - logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}") + logger.warning( + f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}" + ) self.remove(comp_id) self.collections[collection].add(component_id) - logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}") + logger.info( + f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}" + ) else: logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") @@ -379,40 +399,43 @@ def search_components( - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" collection: Optional collection to filter by load_id: Optional load_id to filter by - return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found - If False, returns a dictionary with component IDs as keys + return_dict_with_names: + If True, returns a dictionary with component names as keys, throw an error if + multiple components with the same name are found If False, returns a dictionary + with component IDs as keys Returns: - Dictionary mapping component names to components if return_dict_with_names=True, - or a dictionary mapping component IDs to components if return_dict_with_names=False + Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping + component IDs to components if return_dict_with_names=False """ # select components based on collection and load_id filters selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} - + def get_return_dict(components, return_dict_with_names): """ - Create a dictionary mapping component names to components if return_dict_with_names=True, - or a dictionary mapping component IDs to components if return_dict_with_names=False, - throw an error if duplicate component names are found when return_dict_with_names=True + Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary + mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component + names are found when return_dict_with_names=True """ if return_dict_with_names: dict_to_return = {} for comp_id, comp in components.items(): comp_name = self._id_to_name(comp_id) if comp_name in dict_to_return: - raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) dict_to_return[comp_name] = comp return dict_to_return else: return components - # if no names are provided, return the filtered components as it is if names is None: return get_return_dict(components, return_dict_with_names) - + # if names is not a string, raise an error elif not isinstance(names, str): raise ValueError(f"Invalid type for `names: {type(names)}, only support string") @@ -488,9 +511,7 @@ def matches_pattern(component_id, pattern, exact_match=False): } if is_not_pattern: - logger.info( - f"Getting all components except those with base name '{names}': {list(matches.keys())}" - ) + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") @@ -584,8 +605,8 @@ def disable_auto_cpu_offload(self): # YiYi TODO: (1) add quantization info def get_model_info( - self, - component_id: str, + self, + component_id: str, fields: Optional[Union[str, List[str]]] = None, ) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -603,7 +624,7 @@ def get_model_info( raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] - + # Validate fields if specified if fields is not None: if isinstance(fields, str): @@ -662,7 +683,7 @@ def get_model_info( return {k: v for k, v in info.items() if k in fields} else: return info - + # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table def __repr__(self): # Handle empty components case @@ -820,11 +841,9 @@ def get_one( load_id: Optional[str] = None, ) -> Any: """ - Get a single component by either: - (1) searching name (pattern matching), collection, or load_id. - (2) passing in a component_id - Raises an error if multiple components match or none are found. - support pattern matching for name + Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in + a component_id Raises an error if multiple components match or none are found. support pattern matching for + name Args: component_id: Optional component ID to get @@ -841,7 +860,7 @@ def get_one( if component_id is not None and (name is not None or collection is not None or load_id is not None): raise ValueError("If searching by component_id, do not pass name, collection, or load_id") - + # search by component_id if component_id is not None: if component_id not in self.components: @@ -857,7 +876,6 @@ def get_one( raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") return next(iter(results.values())) - def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): """ @@ -869,7 +887,7 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] for name in names: ids.update(self._lookup_ids(name=name, collection=collection)) return list(ids) - + def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): """ Get components by a list of IDs. @@ -881,7 +899,9 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional for comp_id, comp in components.items(): comp_name = self._id_to_name(comp_id) if comp_name in dict_to_return: - raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) dict_to_return[comp_name] = comp return dict_to_return else: @@ -894,6 +914,7 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] = ids = self.get_ids(names, collection) return self.get_components_by_ids(ids) + def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 7bb393633934..6bdd2f3f3659 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1849,23 +1849,30 @@ def dtype(self) -> torch.dtype: return module.dtype return torch.float32 - + @property def null_component_names(self) -> List[str]: return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None] - + @property def component_names(self) -> List[str]: return list(self.components.keys()) - + @property def pretrained_component_names(self) -> List[str]: - return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] - + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + @property def config_component_names(self) -> List[str]: - return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"] - + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_config" + ] @property def components(self) -> Dict[str, Any]: @@ -2430,9 +2437,13 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = raise ValueError(f"Output '{output}' is not a valid output type") def load_default_components(self, **kwargs): - names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"] + names = [ + name + for name in self.loader._component_specs.keys() + if self.loader._component_specs[name].default_creation_method == "from_pretrained" + ] self.loader.load(names=names, **kwargs) - + def load_components(self, names: Union[List[str], str], **kwargs): self.loader.load(names=names, **kwargs) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 9adb0527958c..95461cfc23c9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -23,18 +23,18 @@ else: _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] _import_structure["modular_blocks_presets"] = [ + "ALL_BLOCKS", "AUTO_BLOCKS", "CONTROLNET_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "IP_ADAPTER_BLOCKS", - "ALL_BLOCKS", "TEXT2IMAGE_BLOCKS", "StableDiffusionXLAutoBlocks", + "StableDiffusionXLAutoControlnetStep", "StableDiffusionXLAutoDecodeStep", "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", - "StableDiffusionXLAutoControlnetStep", ] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] @@ -49,18 +49,18 @@ StableDiffusionXLTextEncoderStep, ) from .modular_blocks_presets import ( + ALL_BLOCKS, AUTO_BLOCKS, CONTROLNET_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, IP_ADAPTER_BLOCKS, - ALL_BLOCKS, TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoBlocks, + StableDiffusionXLAutoControlnetStep, StableDiffusionXLAutoDecodeStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLAutoControlnetStep, ) from .modular_loader import StableDiffusionXLModularLoader else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py index d28eb5387a46..fee955411c02 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -76,9 +76,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): @property def description(self): - return ( - "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" - ) + return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" # before_denoise: text2img diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 34222444dae3..c161c9290f28 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -44,7 +44,6 @@ class StableDiffusionXLModularLoader( StableDiffusionXLLoraLoaderMixin, ModularIPAdapterMixin, ): - @property def default_height(self): return self.default_sample_size * self.vae_scale_factor @@ -52,8 +51,7 @@ def default_height(self): @property def default_width(self): return self.default_sample_size * self.vae_scale_factor - - + @property def default_sample_size(self): default_sample_size = 128 From 92b6b43805f5728c2678b4e4e239eb4867bd5326 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 13:39:45 +0200 Subject: [PATCH 118/170] add some visuals --- docs/source/en/modular_diffusers/quicktour.md | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/quicktour.md index d9008c5e0c4e..09f656978296 100644 --- a/docs/source/en/modular_diffusers/quicktour.md +++ b/docs/source/en/modular_diffusers/quicktour.md @@ -984,13 +984,14 @@ image = pipeline( image.save("modular_ipa_out.png") ``` - -## A more practical example +## Building Advanced Workflows: The Modular Way We've learned the basic components of the Modular Diffusers System. Now let's tie everything together with more practical example that demonstrates the true power of Modular Diffusers: working between with multiple pipelines that can share components. In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will use IP-adapter, LoRA, and ControlNet. +### Base Text-to-Image + Let's setup the text-to-image workflow. Instead of putting all blocks into one complete pipeline, we'll create separate `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. @@ -1179,6 +1180,8 @@ image.save("modular_part2_t2i.png") ``` +### Lora + Now let's add a LoRA to our pipeline. With the modular approach we will be able to reuse intermediate outputs from blocks that otherwise needs to be re-run. Let's load the LoRA weights and see what happens: ```py @@ -1218,6 +1221,8 @@ image = decoder_node(latents=latents_lora, output="images")[0] image.save("modular_part2_lora.png") ``` +### IP-adapter + IP-adapter can also be used as a standalone pipeline. We can generate the embeddings once and reuse them for different workflows. ```py @@ -1247,6 +1252,8 @@ image = decoder_node(latents=latents_ipa, output="images")[0] image.save("modular_part2_lora_ipa.png") ``` +### ControlNet + We can create a new ControlNet workflow by modifying the pipeline blocks, reusing components as much as possible, and see how it affects the generation. We want to use a different ControlNet from the one that's already loaded. @@ -1287,6 +1294,8 @@ refiner_blocks.sub_blocks.pop("image_encoder") refiner_blocks.sub_blocks.pop("decode") ``` +### Refiner + Create refiner pipeline. refiner has a different unet and use only one text_encoder so it is hosted in a different repo. We pass the same components manager to refiner pipeline, along with a unique "refiner" collection. ```py @@ -1358,3 +1367,29 @@ refined_image = decoder_node(latents=refined_latents, output="images")[0] refined_image.save("modular_part2_control_refine_out.png") ``` + +### Results + +Here are the results from our modular pipeline examples. You can find all the generated images in the [Hugging Face dataset](https://huggingface.co/datasets/YiYiXu/testing-images/tree/main/modular_quicktour). + +#### Base Text-to-Image Generation +| Base Text-to-Image | Base Text-to-Image (Refined) | +|-------------------|------------------------------| +| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) | + +#### LoRA +| LoRA | LoRA (Refined) | +|-------------------|------------------------------| +| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) | + +#### LoRA + IP-Adapter +| LoRA + IP-Adapter | LoRA + IP-Adapter (Refined) | +|-------------------|------------------------------| +| ![LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa.png) | ![LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa_refine_out.png) | + +### ControlNet + LoRA + IP-Adapter +| ControlNet + LoRA + IP-Adapter | ControlNet + LoRA + IP-Adapter (Refined) | +|-------------------|------------------------------| +| ![ControlNet + LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control.png) | ![ControlNet + LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control_refine_out.png) | + + From 8c680bc0b42a3644c464aeff1ebd5e414d2a6119 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 14:11:17 +0200 Subject: [PATCH 119/170] up --- docs/source/en/modular_diffusers/quicktour.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/quicktour.md index 09f656978296..172089bc6171 100644 --- a/docs/source/en/modular_diffusers/quicktour.md +++ b/docs/source/en/modular_diffusers/quicktour.md @@ -1370,7 +1370,7 @@ refined_image.save("modular_part2_control_refine_out.png") ### Results -Here are the results from our modular pipeline examples. You can find all the generated images in the [Hugging Face dataset](https://huggingface.co/datasets/YiYiXu/testing-images/tree/main/modular_quicktour). +Here are the results from our modular pipeline examples. #### Base Text-to-Image Generation | Base Text-to-Image | Base Text-to-Image (Refined) | @@ -1387,7 +1387,7 @@ Here are the results from our modular pipeline examples. You can find all the ge |-------------------|------------------------------| | ![LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa.png) | ![LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa_refine_out.png) | -### ControlNet + LoRA + IP-Adapter +#### ControlNet + LoRA + IP-Adapter | ControlNet + LoRA + IP-Adapter | ControlNet + LoRA + IP-Adapter (Refined) | |-------------------|------------------------------| | ![ControlNet + LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control.png) | ![ControlNet + LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control_refine_out.png) | From fdd2bedae9bbff4e23a4e36d916ce9ac6fb60cf5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 29 Jun 2025 03:00:46 +0200 Subject: [PATCH 120/170] 2024 -> 2025; fix a circular import --- docs/source/en/modular_diffusers/developer_guide.md | 2 +- docs/source/en/modular_diffusers/quicktour.md | 2 +- src/diffusers/guiders/__init__.py | 2 +- src/diffusers/guiders/adaptive_projected_guidance.py | 2 +- src/diffusers/guiders/auto_guidance.py | 2 +- src/diffusers/guiders/classifier_free_guidance.py | 2 +- src/diffusers/guiders/classifier_free_zero_star_guidance.py | 2 +- src/diffusers/guiders/guider_utils.py | 2 +- src/diffusers/guiders/skip_layer_guidance.py | 2 +- src/diffusers/guiders/smoothed_energy_guidance.py | 2 +- src/diffusers/guiders/tangential_classifier_free_guidance.py | 2 +- src/diffusers/hooks/_common.py | 2 +- src/diffusers/hooks/_helpers.py | 2 +- src/diffusers/hooks/layer_skip.py | 2 +- src/diffusers/hooks/smoothed_energy_guidance_utils.py | 2 +- src/diffusers/loaders/peft.py | 4 +++- src/diffusers/loaders/unet.py | 3 ++- src/diffusers/modular_pipelines/components_manager.py | 2 +- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- .../modular_pipelines/stable_diffusion_xl/before_denoise.py | 2 +- .../modular_pipelines/stable_diffusion_xl/decoders.py | 2 +- .../modular_pipelines/stable_diffusion_xl/denoise.py | 2 +- .../modular_pipelines/stable_diffusion_xl/encoders.py | 2 +- .../stable_diffusion_xl/modular_blocks_presets.py | 2 +- .../modular_pipelines/stable_diffusion_xl/modular_loader.py | 2 +- 25 files changed, 28 insertions(+), 25 deletions(-) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/developer_guide.md index a4d8337840fb..21c278272cdc 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/developer_guide.md @@ -1,4 +1,4 @@ - + +# `ModularPipelineBlocks` + +In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. + +In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. + + +## Understanding the Foundation: `PipelineState` + +Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system. + +## `PipelineState` + +In the modular diffusers system, `PipelineState` acts as the global state container that `PipelineBlock`s operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. + +While `PipelineState` maintains the complete runtime state of the pipeline, `PipelineBlock`s define what parts of that state they can read from and write to through their `input`s, `intermediates_inputs`, and `intermediates_outputs` properties. + +A `PipelineState` consists of two distinct states: +- The **immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it. +- The **mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them. + +Here's an example of what a `PipelineState` looks like: + +``` +PipelineState( + inputs={ + prompt: 'a cat' + guidance_scale: 7.0 + num_inference_steps: 25 + }, + intermediates={ + prompt_embeds: Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1])) + negative_prompt_embeds: None + }, +``` + +### Creating a `PipelineBlock` + +To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. + +The three main properties you need to define are: +- `inputs`: Immutable values from the user that cannot be modified +- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified +- `intermediate_outputs`: New values your block creates for subsequent blocks + +Let's explore each one and understand how they work with the pipeline state. + +**Inputs: Immutable User Values** + +Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`: + +```py +user_inputs = [ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") +] +``` + +When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged." + +This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency. + +**Intermediate Inputs: Mutable Values from Previous Blocks** + +Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case: + +```py +user_intermediate_inputs = [ + InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"), +] +``` + +When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers can use something different." + +**Intermediate Outputs: New Values for Subsequent Blocks** + +Intermediate outputs are new variables your block creates and adds to the mutable pipeline state so they can be used by subsequent blocks: + +```py +user_intermediate_outputs = [ + OutputParam(name="image_latents", description="latents representing the image") +] +``` + +Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match. We will see more how they connect when we talk about multi-blocks. + +**The `__call__` Method Structure** + +Your `PipelineBlock`'s `__call__` method should follow this structure: + +```py +def __call__(self, components, state): + # Get a local view of the state variables this block needs + block_state = self.get_block_state(state) + + # Your computation logic here + # block_state contains all your inputs and intermediate_inputs + # You can access them like: block_state.image, block_state.processed_image + + # Update the pipeline state with your updated block_states + self.add_block_state(state, block_state) + return components, state +``` + +The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation. + +**Components and Configs** + +You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`: + +```py +from diffusers import ComponentSpec, ConfigSpec + +# Define components your block needs +expected_components = [ + ComponentSpec(name="unet", type_hint=UNet2DConditionModel), + ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler) +] + +# Define pipeline-level configs +expected_config = [ + ConfigSpec("force_zeros_for_empty_prompt", True) +] +``` + +**Components**: You must provide a `name` and ideally a `type_hint`. The actual loading details (repo, subfolder, variant) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). + +**Configs**: Simple pipeline-level settings that control behavior across all blocks. + +When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block in the `components` argument of the `__call__` method. + +That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks + +Note that for `__call__` method, the only part you should implement differently is the part between `get_block_state` and `add_block_state`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. This is why our helper function accepts a `block_fn` parameter that does exactly that. + +**Helper Function** + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.add_block_state(state, block_state) + return components, state + + return TestBlock() +``` + + +Let's create a simple block to see how these definitions interact with the pipeline state: + +```py +user_inputs = [ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") +] + +user_intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] + +user_intermediate_outputs = [ + OutputParam(name="image_latents", description="latents representing the image") +] + +def user_block_fn(block_state, pipeline_state): + print(f"pipeline_state (before update): {pipeline_state}") + print(f"block_state (before update): {block_state}") + + # Simulate processing the image + block_state.image = torch.randn(1, 3, 512, 512) + block_state.batch_size = block_state.batch_size * 2 + block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size + block_state.image_latents = torch.randn(1, 4, 64, 64) + + print(f"block_state (after update): {block_state}") + return block_state + +# Create a block with our definitions +block = make_block( + inputs=user_inputs, + intermediate_inputs=user_intermediate_inputs, + intermediate_outputs=user_intermediate_outputs, + block_fn=user_block_fn +) +pipe = block.init_pipeline() +``` + +Let's check the pipeline's docstring to see what inputs it expects: + +```py +>>> print(pipe.doc) +class TestBlock + + Inputs: + + image (`PIL.Image`, *optional*): + raw input image to process + + batch_size (`int`, *optional*): + + Outputs: + + image_latents (`None`): + latents representing the image +``` + +Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state. + +Now let's run the pipeline: + +```py +from diffusers.utils import load_image + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png") +state = pipe(image=image, batch_size=2) +print(f"pipeline_state (after update): {state}") +``` + +```out +pipeline_state (before update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 2 + }, +) +block_state (before update): BlockState( + image: + batch_size: 2 +) + +block_state (after update): BlockState( + image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) + batch_size: 4 + processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) +) +pipeline_state (after update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 4 + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) + }, +) +``` + +**Key Observations:** + +1. **Before the update**: `image` goes to the immutable inputs dict, while `batch_size` goes to the mutable intermediates dict, and both are available in `block_state`. + +2. **After the update**: + - **`image` modification**: Changed in `block_state` but not in `pipeline_state` - this change is local to the block only + - **`batch_size` modification**: Updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) + - **`image_latents` creation**: Added to `pipeline_state` because it was declared as an intermediate output + - **`processed_image` creation**: Not added to `pipeline_state` because it wasn't declared as an intermediate output + +I hope by now you have a basic idea about how `PipelineBlock` manages state through inputs, intermediate inputs, and intermediate outputs. The real power comes when we connect multiple blocks together - their intermediate outputs become intermediate inputs for subsequent blocks, creating modular workflows. Let's explore how to build these connections using multi-blocks like `SequentialPipelineBlocks`. From 9fae3828a70d1e8c218c44fe72c3e07c2612f89c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 29 Jun 2025 14:49:31 -1000 Subject: [PATCH 122/170] Apply suggestions from code review --- docs/source/en/modular_diffusers/write_own_pipeline_block.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index a359c759edac..5c65dabdc180 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -17,11 +17,10 @@ In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. -## Understanding the Foundation: `PipelineState` +### Understanding the Foundation: `PipelineState` Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system. -## `PipelineState` In the modular diffusers system, `PipelineState` acts as the global state container that `PipelineBlock`s operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. From b43e703fae38f77f6dfa67bdfba0700eca7258f0 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 29 Jun 2025 14:49:54 -1000 Subject: [PATCH 123/170] Update docs/source/en/modular_diffusers/write_own_pipeline_block.md --- docs/source/en/modular_diffusers/write_own_pipeline_block.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 5c65dabdc180..7c621eee95d0 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -21,7 +21,6 @@ In this tutorial, we will focus on how to write a basic `PipelineBlock` and how Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system. - In the modular diffusers system, `PipelineState` acts as the global state container that `PipelineBlock`s operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. While `PipelineState` maintains the complete runtime state of the pipeline, `PipelineBlock`s define what parts of that state they can read from and write to through their `input`s, `intermediates_inputs`, and `intermediates_outputs` properties. From c75b88f86f870e3f3d5942400fc23a431d096a1d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 03:23:44 +0200 Subject: [PATCH 124/170] up --- .../write_own_pipeline_block.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 7c621eee95d0..656bc59e4044 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -79,7 +79,7 @@ user_intermediate_inputs = [ ] ``` -When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers can use something different." +When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different." **Intermediate Outputs: New Values for Subsequent Blocks** @@ -132,7 +132,7 @@ expected_config = [ ] ``` -**Components**: You must provide a `name` and ideally a `type_hint`. The actual loading details (repo, subfolder, variant) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). +**Components**: You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). **Configs**: Simple pipeline-level settings that control behavior across all blocks. @@ -140,7 +140,7 @@ When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks -Note that for `__call__` method, the only part you should implement differently is the part between `get_block_state` and `add_block_state`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. This is why our helper function accepts a `block_fn` parameter that does exactly that. +Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.add_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. **Helper Function** @@ -179,7 +179,7 @@ def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block ``` -Let's create a simple block to see how these definitions interact with the pipeline state: +Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them: ```py user_inputs = [ @@ -279,12 +279,12 @@ pipeline_state (after update): PipelineState( **Key Observations:** -1. **Before the update**: `image` goes to the immutable inputs dict, while `batch_size` goes to the mutable intermediates dict, and both are available in `block_state`. +1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. 2. **After the update**: - - **`image` modification**: Changed in `block_state` but not in `pipeline_state` - this change is local to the block only - - **`batch_size` modification**: Updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) - - **`image_latents` creation**: Added to `pipeline_state` because it was declared as an intermediate output - - **`processed_image` creation**: Not added to `pipeline_state` because it wasn't declared as an intermediate output + - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only. + - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) + - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output + - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output I hope by now you have a basic idea about how `PipelineBlock` manages state through inputs, intermediate inputs, and intermediate outputs. The real power comes when we connect multiple blocks together - their intermediate outputs become intermediate inputs for subsequent blocks, creating modular workflows. Let's explore how to build these connections using multi-blocks like `SequentialPipelineBlocks`. From 285f8776202c26a04b086abfa59774f14035f2d6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 07:48:26 +0200 Subject: [PATCH 125/170] make InsertableDict importable from modular_pipelines --- src/diffusers/modular_pipelines/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 16f0becd8850..9b18c8b048f9 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -38,6 +38,7 @@ "ConfigSpec", "InputParam", "OutputParam", + "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularLoader"] _import_structure["components_manager"] = ["ComponentsManager"] @@ -65,6 +66,7 @@ ComponentSpec, ConfigSpec, InputParam, + InsertableDict, OutputParam, ) from .stable_diffusion_xl import ( From f09b1ccfaebb3af7221997855a77dd2e1ffdbe77 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 07:48:44 +0200 Subject: [PATCH 126/170] start the section on sequential pipelines --- .../write_own_pipeline_block.md | 140 ++++++++++++------ 1 file changed, 96 insertions(+), 44 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 656bc59e4044..62a55467cd6b 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -17,7 +17,7 @@ In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. -### Understanding the Foundation: `PipelineState` +## Understanding the Foundation: `PipelineState` Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system. @@ -44,7 +44,7 @@ PipelineState( }, ``` -### Creating a `PipelineBlock` +## Creating a `PipelineBlock` To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. @@ -182,17 +182,17 @@ def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them: ```py -user_inputs = [ +inputs = [ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") ] -user_intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] +intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] -user_intermediate_outputs = [ +intermediate_outputs = [ OutputParam(name="image_latents", description="latents representing the image") ] -def user_block_fn(block_state, pipeline_state): +def image_encoder_block_fn(block_state, pipeline_state): print(f"pipeline_state (before update): {pipeline_state}") print(f"block_state (before update): {block_state}") @@ -206,21 +206,23 @@ def user_block_fn(block_state, pipeline_state): return block_state # Create a block with our definitions -block = make_block( - inputs=user_inputs, - intermediate_inputs=user_intermediate_inputs, - intermediate_outputs=user_intermediate_outputs, - block_fn=user_block_fn +image_encoder_block = make_block( + inputs=inputs, + intermediate_inputs=intermediate_inputs, + intermediate_outputs=intermediate_outputs, + block_fn=image_encoder_block_fn, + description=" Encode raw image into its latent presentation" ) -pipe = block.init_pipeline() +pipe = image_encoder_block.init_pipeline() ``` Let's check the pipeline's docstring to see what inputs it expects: - ```py >>> print(pipe.doc) class TestBlock + Encode raw image into its latent presentation + Inputs: image (`PIL.Image`, *optional*): @@ -246,37 +248,6 @@ state = pipe(image=image, batch_size=2) print(f"pipeline_state (after update): {state}") ``` -```out -pipeline_state (before update): PipelineState( - inputs={ - image: - }, - intermediates={ - batch_size: 2 - }, -) -block_state (before update): BlockState( - image: - batch_size: 2 -) - -block_state (after update): BlockState( - image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) - batch_size: 4 - processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] - image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) -) -pipeline_state (after update): PipelineState( - inputs={ - image: - }, - intermediates={ - batch_size: 4 - image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) - }, -) -``` - **Key Observations:** 1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. @@ -288,3 +259,84 @@ pipeline_state (after update): PipelineState( - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output I hope by now you have a basic idea about how `PipelineBlock` manages state through inputs, intermediate inputs, and intermediate outputs. The real power comes when we connect multiple blocks together - their intermediate outputs become intermediate inputs for subsequent blocks, creating modular workflows. Let's explore how to build these connections using multi-blocks like `SequentialPipelineBlocks`. + +## Create a `SequentialPipelineBlocks` + +I think by this point, you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [quicktour](https://moon-ci-docs.huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/quicktour#modularpipelineblocks). + +But how do blocks actually connect and work together? Understanding this is crucial for building effective modular workflows. Let's explore this through an example. + +**How Blocks Connect in SequentialPipelineBlocks:** + +The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed earlier. Let's expand on our example to create a new block that produces `batch_size`, which we'll call "input_block": + +```py +def input_block_fn(block_state, pipeline_state): + + # Simulate processing the image + if not isinstance(block_state.prompt, list): + prompt = [block_state.prompt] + batch_size = len(block_state.prompt) + block_state.batch_size = batch_size * block_state.num_images_per_prompt + + return block_state + +input_block = make_block( + inputs=[ + InputParam(name="prompt", type_hint=list, description="list of text prompts"), + InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt") + ], + intermediate_outputs=[ + OutputParam(name="batch_size", description="calculated batch size") + ], + block_fn=input_block_fn, + description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument." +) +``` + +Now let's connect these blocks to create a pipeline: + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict +blocks_dict = InsertableDict() +blocks_dict["input"] = input_block +blocks_dict["image_encoder"] = image_encoder_block +blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +pipeline = blocks.init_pipeline() +``` + +Now you have a pipeline with 2 blocks. When you inspect `pipeline.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. + +```py +>>> print(pipeline.doc) +class SequentialPipelineBlocks + + Inputs: + + prompt (`None`, *optional*): + + num_images_per_prompt (`None`, *optional*): + + image (`PIL.Image`, *optional*): + raw input image to process + + Outputs: + + batch_size (`None`): + + image_latents (`None`): + latents representing the image +``` + +At runtime, you have data flow like this: + +![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/sequential_mermaid.png) + +**How SequentialPipelineBlocks Works:** + +1. **Execution Order**: Blocks are executed in the order they're registered in the `blocks_dict` +2. **Data Flow**: Outputs from one block become available as intermediate inputs to all subsequent blocks +3. **Smart Input Resolution**: The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks +4. **Consistent Interface**: Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces + +What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. From c5849ba9d580e789e5662c6825f3e582961ee428 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 09:46:34 +0200 Subject: [PATCH 127/170] more --- docs/source/en/modular_diffusers/write_own_pipeline_block.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 62a55467cd6b..71b1ef7f1951 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -273,9 +273,6 @@ The key insight is that blocks connect through their intermediate inputs and out ```py def input_block_fn(block_state, pipeline_state): - # Simulate processing the image - if not isinstance(block_state.prompt, list): - prompt = [block_state.prompt] batch_size = len(block_state.prompt) block_state.batch_size = batch_size * block_state.num_images_per_prompt @@ -340,3 +337,5 @@ At runtime, you have data flow like this: 4. **Consistent Interface**: Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. + +## `LoopSequentialPipelineBlocks` \ No newline at end of file From 363737ec4b52b09fd9349de320b80ba466aabcf7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 11:09:08 +0200 Subject: [PATCH 128/170] add loop sequential blocks --- .../write_own_pipeline_block.md | 178 +++++++++++++++++- .../modular_pipelines/modular_pipeline.py | 9 + 2 files changed, 181 insertions(+), 6 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 71b1ef7f1951..ea9d43cc9405 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -262,7 +262,7 @@ I hope by now you have a basic idea about how `PipelineBlock` manages state thro ## Create a `SequentialPipelineBlocks` -I think by this point, you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [quicktour](https://moon-ci-docs.huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/quicktour#modularpipelineblocks). +I assume that you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [Getting Started Guide](https://moon-ci-docs.huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/quicktour#modularpipelineblocks). But how do blocks actually connect and work together? Understanding this is crucial for building effective modular workflows. Let's explore this through an example. @@ -331,11 +331,177 @@ At runtime, you have data flow like this: **How SequentialPipelineBlocks Works:** -1. **Execution Order**: Blocks are executed in the order they're registered in the `blocks_dict` -2. **Data Flow**: Outputs from one block become available as intermediate inputs to all subsequent blocks -3. **Smart Input Resolution**: The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks -4. **Consistent Interface**: Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces +1. Blocks are executed in the order they're registered in the `blocks_dict` +2. Outputs from one block become available as intermediate inputs to all subsequent blocks +3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks +4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. -## `LoopSequentialPipelineBlocks` \ No newline at end of file +## `LoopSequentialPipelineBlocks` + +To create a loop in Modular Diffusers, you could use a single `PipelineBlock` like this: + +```python +class DenoiseLoop(PipelineBlock): + def __call__(self, components, state): + block_state = self.get_block_state(state) + for t in range(block_state.num_inference_steps): + # ... loop logic here + pass + self.add_block_state(state, block_state) + return components, state +``` + +Or you could create a `LoopSequentialPipelineBlocks`. The key difference is that with `LoopSequentialPipelineBlocks`, the loop itself is modular: you can add or remove blocks within the loop or reuse the same loop structure with different block combinations. + +It involves two parts: a **loop wrapper** and **loop blocks** + +* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar. + +* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper. + - they run sequentially for each iteration of the loop + - they receive the current iteration index as an additional parameter + - they share the same block_state throughout the entire loop + +Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations. + +We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps: +1. defining the loop wrapper class +2. creating the loop blocks +3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance + +**Step 1: Define the Loop Wrapper** + +To create a `LoopSequentialPipelineBlocks` class, you need to define: + +* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`) +* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`) +* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`) +* `__call__` method: Defines the loop structure and iteration logic + +Here is an example of a loop wrapper: + +```py +import torch +from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam + +class LoopWrapper(LoopSequentialPipelineBlocks): + model_name = "test" + @property + def description(self): + return "I'm a loop!!" + @property + def loop_inputs(self): + return [InputParam(name="num_steps")] + @torch.no_grad() + def __call__(self, components, state): + block_state = self.get_block_state(state) + # Loop structure - can be customized to your needs + for i in range(block_state.num_steps): + # loop_step executes all registered blocks in sequence + components, block_state = self.loop_step(components, block_state, i=i) + self.add_block_state(state, block_state) + return components, state +``` + +**Step 2: Create Loop Blocks** + +Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: +* It receives the iteration variable (e.g., `i`) passed by the loop wrapper +* It works directly with `block_state` instead of pipeline state +* No need to call `self.get_block_state()` or `self.add_block_state()` + +```py +class LoopBlock(PipelineBlock): + # this is used to identify the model family, we won't worry about it in this example + model_name = "test" + @property + def inputs(self): + return [InputParam(name="x")] + @property + def intermediate_outputs(self): + # outputs produced by this block + return [OutputParam(name="x")] + @property + def description(self): + return "I'm a block used inside the `LoopWrapper` class" + def __call__(self, components, block_state, i: int): + block_state.x += 1 + return components, block_state +``` + +**Step 3: Combine Everything** + +Finally, assemble your loop by adding the block(s) to the wrapper: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock}) +``` + +Now you've created a loop with one step: + +```py +>>> loop +LoopWrapper( + Class: LoopSequentialPipelineBlocks + + Description: I'm a loop!! + + Sub-Blocks: + [0] block1 (LoopBlock) + Description: I'm a block used inside the `LoopWrapper` class + +) +``` + +It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop. + +```py +>>> print(loop.doc) +class LoopWrapper + + I'm a loop!! + + Inputs: + + x (`None`, *optional*): + + num_steps (`None`, *optional*): + + Outputs: + + x (`None`): +``` + +**Running the Loop:** + +```py +# run the loop +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 10 +``` + +**Adding Multiple Blocks:** + +We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock}) +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20 +``` + +**Key Differences from SequentialPipelineBlocks:** + +The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself. + +The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice: + +```py +from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep +StableDiffusionXLDenoiseStep() +``` + diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index fe7c4df50d87..c9d578b830ea 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1452,6 +1452,15 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelin A new LoopSequentialPipelineBlocks instance """ instance = cls() + + # Create instances if classes are provided + sub_blocks = InsertableDict() + for name, block in blocks_dict.items(): + if inspect.isclass(block): + sub_blocks[name] = block() + else: + sub_blocks[name] = block + instance.block_classes = [block.__class__ for block in blocks_dict.values()] instance.block_names = list(blocks_dict.keys()) instance.sub_blocks = blocks_dict From bbd93407810109412372074496688648e07f7ab4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 11:30:06 +0200 Subject: [PATCH 129/170] up --- .../write_own_pipeline_block.md | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index ea9d43cc9405..01c693a88d48 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -295,14 +295,37 @@ Now let's connect these blocks to create a pipeline: ```py from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict +# define a dict map block names to block class blocks_dict = InsertableDict() blocks_dict["input"] = input_block blocks_dict["image_encoder"] = image_encoder_block +# create the multi-block blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +# convert it to a runnable pipeline pipeline = blocks.init_pipeline() ``` -Now you have a pipeline with 2 blocks. When you inspect `pipeline.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. +Now you have a pipeline with 2 blocks. + +``py +>>> pipeline.blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] input (TestBlock) + Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument. + + [1] image_encoder (TestBlock) + Description: Encode raw image into its latent presentation + +) +``` + +When you inspect `pipeline.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. ```py >>> print(pipeline.doc) @@ -327,7 +350,7 @@ class SequentialPipelineBlocks At runtime, you have data flow like this: -![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/sequential_mermaid.png) +![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png) **How SequentialPipelineBlocks Works:** From 0138e176aca8b705f67f505ca669cbda5b407137 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 21:05:12 +0200 Subject: [PATCH 130/170] remove the get_exeuction_blocks rec from AutoPipelineBlocks repr --- src/diffusers/modular_pipelines/modular_pipeline.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c9d578b830ea..1ada2fc41473 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -815,10 +815,7 @@ def __repr__(self): header += "\n" header += " " + "=" * 100 + "\n" header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" header += " " + "=" * 100 + "\n\n" # Format description with proper indentation @@ -1178,7 +1175,7 @@ def __repr__(self): header += "\n" header += " " + "=" * 100 + "\n" header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" + header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" # Get first trigger input as example example_input = next(t for t in self.trigger_inputs if t is not None) header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" From db4b54cfabd7d5590d010a641bf91c49d2678a12 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 21:05:32 +0200 Subject: [PATCH 131/170] finish the autopipelines section! --- .../write_own_pipeline_block.md | 260 +++++++++++++++++- 1 file changed, 256 insertions(+), 4 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 01c693a88d48..45671b3ca502 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -175,7 +175,7 @@ def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block self.add_block_state(state, block_state) return components, state - return TestBlock() + return TestBlock ``` @@ -206,13 +206,14 @@ def image_encoder_block_fn(block_state, pipeline_state): return block_state # Create a block with our definitions -image_encoder_block = make_block( +image_encoder_block_cls = make_block( inputs=inputs, intermediate_inputs=intermediate_inputs, intermediate_outputs=intermediate_outputs, block_fn=image_encoder_block_fn, description=" Encode raw image into its latent presentation" ) +image_encoder_block = image_encoder_block_cls() pipe = image_encoder_block.init_pipeline() ``` @@ -278,7 +279,7 @@ def input_block_fn(block_state, pipeline_state): return block_state -input_block = make_block( +input_block_cls = make_block( inputs=[ InputParam(name="prompt", type_hint=list, description="list of text prompts"), InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt") @@ -289,6 +290,7 @@ input_block = make_block( block_fn=input_block_fn, description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument." ) +input_block = input_block_cls() ``` Now let's connect these blocks to create a pipeline: @@ -307,7 +309,7 @@ pipeline = blocks.init_pipeline() Now you have a pipeline with 2 blocks. -``py +```py >>> pipeline.blocks SequentialPipelineBlocks( Class: ModularPipelineBlocks @@ -528,3 +530,253 @@ from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffus StableDiffusionXLDenoiseStep() ``` +## `AutoPipelineBlocks` + +`AutoPipelineBlocks` allows you to pack different pipelines into one and automatically select which one to run at runtime based on the inputs. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use. + +For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided. + +Let's see an example. Here we'll create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines. + + +```py +from diffusers.modular_pipelines import AutoPipelineBlocks + +# These are dummy blocks and we only focus on "inputs" for our purpose +inputs = [InputParam(name="prompt")] +# block_fn prints out which workflow is running so we can see the execution order at runtime +block_fn = lambda x, y: print("running the text-to-image workflow") +block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image")] +block_fn = lambda x, y: print("running the image-to-image workflow") +block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")] +block_fn = lambda x, y: print("running the inpaint workflow") +block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!") + +class AutoImageBlocks(AutoPipelineBlocks): + # List of sub-block classes to choose from + block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls] + # Names for each block in the same order + block_names = ["inpaint", "img2img", "text2img"] + # Trigger inputs that determine which block to run + # - "mask" triggers inpaint workflow + # - "image" triggers img2img workflow (but only if mask is not provided) + # - if none of above, runs the text2img workflow (default) + block_trigger_inputs = ["mask", "image", None] + # Description is extremely important for AutoPipelineBlocks + @property + def description(self): + return ( + "Pipeline generates images given different types of conditions!\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + + " - inpaint workflow is run when `mask` is provided.\n" + + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n" + + " - text2img workflow is run when neither `image` nor `mask` is provided.\n" + ) + +# Create the blocks +auto_blocks = AutoImageBlocks() +# convert to pipeline +auto_pipeline = auto_blocks.init_pipeline() +``` + +Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime. + +```py +AutoImageBlocks( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['mask', 'image'] + ==================================================================================================== + + + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + + + Sub-Blocks: + • inpaint [trigger: mask] (TestBlock) + Description: I'm a inpaint workflow! + + • img2img [trigger: image] (TestBlock) + Description: I'm a image-to-image workflow! + + • text2img [default] (TestBlock) + Description: I'm a text-to-image workflow! + +) +``` + +Check out the documentation with `print(auto_pipeline.doc)`: + +```py +>>> print(auto_pipeline.doc) +class AutoImageBlocks + + Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + Inputs: + + prompt (`None`, *optional*): + + image (`None`, *optional*): + + mask (`None`, *optional*): +``` + +There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right? + +This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks. + +Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running: + +```py +>>> _ = auto_pipeline(image="image", mask="mask") +running the inpaint workflow +>>> _ = auto_pipeline(image="image") +running the image-to-image workflow +>>> _ = auto_pipeline(prompt="prompt") +running the text-to-image workflow +>>> _ = auto_pipeline(image="prompt", mask="mask") +running the inpaint workflow +``` + +However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines. + +Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime. + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict +inputs = [InputParam(name="ip_adapter_image")] +block_fn = lambda x, y: print("running the ip-adapter workflow") +block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!") + +class AutoIPAdapter(AutoPipelineBlocks): + block_classes = [block_ipa_cls] + block_names = ["ip-adapter"] + block_trigger_inputs = ["ip_adapter_image"] + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" +``` + +Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`: + +```py +auto_ipa_blocks = AutoIPAdapter() +blocks_dict = InsertableDict() +blocks_dict["ip-adapter"] = auto_ipa_blocks +blocks_dict["image-generation"] = auto_blocks +all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +pipeline = all_blocks.init_pipeline() +``` + +Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases. + +```py +>>> all_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['image', 'mask', 'ip_adapter_image'] + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`). + ==================================================================================================== + + + Description: + + + Sub-Blocks: + [0] ip-adapter (AutoIPAdapter) + Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step. + + + [1] image-generation (AutoImageBlocks) + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + +) + +``` + +This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs. + +Let's try some examples: + +`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block. + +```py +>>> all_blocks.get_execution_blocks('mask') +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] image-generation (TestBlock) + Description: I'm a inpaint workflow! + +) +``` + +Let's also actually run the pipeline to confirm: + +```py +>>> _ = pipeline(mask="mask") +skipping auto block: AutoIPAdapter +running the inpaint workflow +``` + +Try a few more: + +```py +print(f"inputs: ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt") +# expect to see ip-adapter + text2img + +print(f"inputs: image:") +blocks_select = all_blocks.get_execution_blocks('image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(image="image", prompt="prompt") +# expect to see img2img + +print(f"inputs: prompt:") +blocks_select = all_blocks.get_execution_blocks('prompt') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(prompt="prompt") +# expect to see text2img (prompt is not a trigger input so fallback to default) + +print(f"inputs: mask + ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image") +# expect to see ip-adapter + inpaint +``` + +In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions. \ No newline at end of file From abf28d55fb00c0f99e55cba8785e1eaf1757feda Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 21:45:30 +0200 Subject: [PATCH 132/170] update --- .../write_own_pipeline_block.md | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 45671b3ca502..a935380712ff 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -248,7 +248,36 @@ image = load_image("https://huggingface.co/datasets/huggingface/documentation-im state = pipe(image=image, batch_size=2) print(f"pipeline_state (after update): {state}") ``` +```out +pipeline_state (before update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 2 + }, +) +block_state (before update): BlockState( + image: + batch_size: 2 +) +block_state (after update): BlockState( + image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) + batch_size: 4 + processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) +) +pipeline_state (after update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 4 + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) + }, +) +``` **Key Observations:** 1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. @@ -670,7 +699,7 @@ class AutoIPAdapter(AutoPipelineBlocks): block_trigger_inputs = ["ip_adapter_image"] @property def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" + return "Run IP Adapter step if `ip_adapter_image` is provided." ``` Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`: @@ -703,7 +732,7 @@ SequentialPipelineBlocks( Sub-Blocks: [0] ip-adapter (AutoIPAdapter) - Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step. + Description: Run IP Adapter step if `ip_adapter_image` is provided. [1] image-generation (AutoImageBlocks) From f27fbceba1829555452c852450f10a5895a240b0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Jun 2025 22:09:57 +0200 Subject: [PATCH 133/170] more attemp to fix circular import --- src/diffusers/loaders/lora_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 562a21dbbb74..c072165dedc2 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -25,7 +25,6 @@ from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE -from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, @@ -331,6 +330,8 @@ def _load_lora_into_text_encoder( hotswap: bool = False, metadata=None, ): + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -442,6 +443,8 @@ def _func_optionally_disable_offloading(_pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. """ + from ..hooks.group_offloading import _is_group_offload_enabled + is_model_cpu_offload = False is_sequential_cpu_offload = False is_group_offload = False From b5db8aaa6fa3dbcc8cd7d11c8f9f4d587be75f5b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:05:38 +0200 Subject: [PATCH 134/170] developer_guide -> end-to-end guide --- ...developer_guide.md => end_to_end_guide.md} | 337 +++++++----------- 1 file changed, 127 insertions(+), 210 deletions(-) rename docs/source/en/modular_diffusers/{developer_guide.md => end_to_end_guide.md} (66%) diff --git a/docs/source/en/modular_diffusers/developer_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md similarity index 66% rename from docs/source/en/modular_diffusers/developer_guide.md rename to docs/source/en/modular_diffusers/end_to_end_guide.md index 21c278272cdc..0784ce2e1fbd 100644 --- a/docs/source/en/modular_diffusers/developer_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -10,35 +10,40 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Developer Guide: Building with Modular Diffusers +# End-to-End Developer Guide: Building with Modular Diffusers -[[open-in-colab]] In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node. -We'll also demonstrate the 3-step framework process we use for implementing new basic pipelines in the modular system. +We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system. -#### 1. **Start with an existing pipeline as a base** - - Identify which existing pipeline is most similar to your target - - Determine what part of the pipeline need modification +1. **Start with an existing pipeline as a base** + - Identify which existing pipeline is most similar to the one you want to implement + - Determine what part of the pipeline needs modification -#### 2. **Build a working pipeline structure first** +2. **Build a working pipeline structure first** - Assemble the complete pipeline structure - Use existing blocks wherever possible - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet -#### 3. **Set up an example and test incrementally** +3. **Set up an example** - Create a simple inference script with expected inputs/outputs - - Test incrementally as you implement changes + +4. **Implement your custom logic and test incrementally** + - Add the custom logics the blocks you want to change + - Test incrementally, and inspect pipeline states and debug as needed Let's see how this works with the Differential Diffusion example. ## Differential Diffusion Pipeline +### Start with an existing pipeline + Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them. ```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS >>> IMAGE2IMAGE_BLOCKS = InsertableDict([ ... ("text_encoder", StableDiffusionXLTextEncoderStep), ... ("image_encoder", StableDiffusionXLVaeEncoderStep), @@ -46,12 +51,12 @@ Differential diffusion (https://differential-diffusion.github.io/) is an image-t ... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), ... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), ... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), -... ("denoise", StableDiffusionXLDenoiseLoop), +... ("denoise", StableDiffusionXLDenoiseStep), ... ("decode", StableDiffusionXLDecodeStep) ... ]) ``` -Note that "denoise" (`StableDiffusionXLDenoiseLoop`) is a loop that contains 3 loop blocks (more on SequentialLoopBlocks [here](https://colab.research.google.com/drive/1iVRjy_tOfmmm4gd0iVe0_Rl3c6cBzVqi?usp=sharing)) +Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks)) ```py >>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() @@ -59,7 +64,7 @@ Note that "denoise" (`StableDiffusionXLDenoiseLoop`) is a loop that contains 3 l ``` ```out -StableDiffusionXLDenoiseLoop( +StableDiffusionXLDenoiseStep( Class: StableDiffusionXLDenoiseLoopWrapper Description: Denoise step that iteratively denoise the latents. @@ -68,7 +73,7 @@ StableDiffusionXLDenoiseLoop( - `StableDiffusionXLLoopBeforeDenoiser` - `StableDiffusionXLLoopDenoiser` - `StableDiffusionXLLoopAfterDenoiser` - + This block supports both text2img and img2img tasks. Components: @@ -76,7 +81,7 @@ StableDiffusionXLDenoiseLoop( guider (`ClassifierFreeGuidance`) unet (`UNet2DConditionModel`) - Blocks: + Sub-Blocks: [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser) Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) @@ -89,18 +94,23 @@ StableDiffusionXLDenoiseLoop( ) ``` +Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image. -Img2img diffusion pipeline adds the same noise level across all pixels based on a single strength parameter, however, differential diffusion uses a change map where each pixel value represents when that region should start denoising. Regions with lower change map values get "frozen" earlier in the denoising process by replacing them with noised original latents, effectively giving them fewer denoising steps and thus preserving more of the original image. - -It has a different `prepare_latents` step and `denoise` step. At `parepare_latents` step, it prepares the change map and pre-computes the original noised latents for all timesteps. At each timestep during the denoising process, it selectively applies denoising based on the change map. Additionally, diff-diff does not use the `strengh` parameter, so its `set_timesteps` step is different from the one in image-to-image, but same as `set_timesteps` in text-to-image workflow. +Therefore, the key differences when it comes to pipeline implementation would be: +1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps) +2. The `denoise` step (which selectively applies denoising based on the change map) +3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version -So, to implement the differential diffusion pipeline, we can use pipeline blocks from image-to-image and text-to-image workflow, and change the `prepare_latents` step and the `denoise` step (more specifically, we only need to change the first part of `denoise` step where we prepare the latent input for the denoiser model). +To implement differntial diffusion, we can reuse most blocks from image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`). -Differential diffusion shares exact same pipeline structure as img2img. Here is a flowchart that puts the changes we need to make into the context of the pipeline structure. +Here's a flowchart showing the pipeline structure and the changes we need to make: ![DiffDiff Pipeline Structure](https://mermaid.ink/img/pako:eNqVVO9r4kAQ_VeWLQWFKEk00eRDwZpa7Q-ucPfpYpE1mdWlcTdsVmpb-7_fZk1tTCl3J0Sy8968N5kZ9g0nIgUc4pUk-Rr9iuYc6d_Ibs14vlXoQYpNrtqo07lAo1jBTi2AlynysWIa6DJmG7KCBnZpsHHMSqkqNjaxKC5ALRTbQKEgLyosMthVnEvIiYRFRhRwVaBoNpmUT0W7MrTJkUbSdJEInlbwxMDXcQpcsAKq6OH_2mDTODIY4yt0J0ReUaYGnLXiJVChdSsB-enfPhBnhnjT-rCQj-1K_8Ygt62YUAVy8Ykf4FvU6XYu9rpuIGqPpvXSzs_RVEj2KrgiGUp02zNQTHBEM_FcK3BfQbBHd7qAst-PxvW-9WOrypnNylG0G9oRUMYBFeolg-IQTTJSFDqOUkZp-fwsQURZloVnlPpLf2kVSoonCM-SwCUuqY6dZ5aqddjLd1YiMiFLNrWorrxj9EOmP4El37lsl_9p5PzFqIqwVwgdN981fDM94bphH5I06R8NXZ_4QcPQPTFs6JltPrS6JssFhw9N817l27bdyM-lSKAo6iVBAAnQY0n9wLO9wbcluY7ruUFDtdguH74K0yENKDkK-8nAG6TfNrfy_bf-HjdrlOfZS7VYSAlU5JAwyhLE9WrWVw1dWdPTXauDsy8LUkdHtnX_pfMnBOvSGluRNbGurbuTHtdZN9Zts1MljC19_7EUh0puwcIbkBtSHvFbic6xWsMG5jjUrymRT3M85-86Jyf8txCbjzQptqs1DinJCn3a5qm-viJG9M26OUYlcH0_jsWWKxwGttHA4Rve4dD1el3H8_yh49hD3_X7roVfcNhx-l3b14PxvGHQ0xMa9t4t_Gp8na7tDvu-4w08HXecweD9D4X54ZI) + +### Build a Working Pipeline Structure + ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones: ```py @@ -114,10 +124,10 @@ ok now we've identified the blocks to modify, let's build the pipeline skeleton ... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser ``` -`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseLoop`. +`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`. ```py ->>> class SDXLDiffDiffDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): ... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] ... block_names = ["before_denoiser", "denoiser", "after_denoiser"] ``` @@ -128,18 +138,20 @@ Now we can put together our differential diffusion pipeline. >>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() >>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] >>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep ->>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseLoop +>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep >>> >>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS) >>> print(dd_blocks) >>> # At this point, the pipeline works exactly like img2img since our blocks are just copies ``` -ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple exapmple so we can run the pipeline as we build it. diff-diff use same components as SDXL so we can fetch the models from a regular SDXL repo. +### Set up an example + +ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple example so we can run the pipeline as we build it. diff-diff use same model checkpoints as SDXL so we can fetch the models from a regular SDXL repo. ```py >>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") ->>> dd_pipeline.load_componenets(torch_dtype=torch.float16) +>>> dd_pipeline.load_default_componenets(torch_dtype=torch.float16) >>> dd_pipeline.to("cuda") ``` @@ -167,12 +179,17 @@ We will use this example script: If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`. and you would get the same result as the original img2img pipeline. +### implement your custom logic and test incrementally + Let's modify the pipeline so that we can get expected result with this example script. -We'll start with the `prepare_latents` step, as it is the first step that gets called right after the `input` step. Let's first apply changes in inputs/outputs/components. The main changes are: -- new input `diffdiff_map` -- new intermediates inputs `num_inference_steps` and `timestesp`. Both variables are already created in `set_timesteps` block, we can now need to use them in `prepare_latents` step. -- A new component `mask_processor` to process the `diffdiff_map` +We'll start with the `prepare_latents` step. The main changes are: +- Requires a new user input `diffdiff_map` +- Requires new component `mask_processor` to process the `diffdiff_map` +- Requires new intermediate inputs: + - Need `timestep` instead of `latent_timestep` to precompute all the latents + - Need `num_inference_steps` to create the `diffdiff_masks` +- create a new output `diffdiff_masks` and `original_latents` @@ -182,7 +199,7 @@ e.g. after we added `diffdiff_map` as an input in this step, we can run `print(d -Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`. We also need to list them as intermediates outputs so the we can use them in the `denoise` step later. +Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`. @@ -190,161 +207,91 @@ Once we make sure all the variables we need are available in the block state, we -This is the modified `StableDiffusionXLImg2ImgPrepareLatentsStep` we ended up with : +Here are the key changes we made to implement differential diffusion: + +**1. Modified `prepare_latents` step:** ```diff -- class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): -+ class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( -- "Step that prepares the latents for the image-to-image generation process" -+ "Step that prepares the latents for the differential diffusion generation process" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), -+ ComponentSpec( -+ "mask_processor", -+ VaeImageProcessor, -+ config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}), -+ default_creation_method="from_config", -+ ) - ] - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ -+ InputParam("diffdiff_map",required=True), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ - InputParam("generator"), -- InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), -+ InputParam("timesteps",type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."), -+ InputParam("num_inference_steps", type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ -+ OutputParam("original_latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), -+ OutputParam("diffdiff_masks", type_hint=torch.Tensor, description="The masks used for the differential diffusion denoising process"), - ] - - @torch.no_grad() - def __call__(self, components, state: PipelineState): - block_state = self.get_block_state(state) - block_state.dtype = components.vae.dtype - block_state.device = components._execution_device - - block_state.add_noise = True if block_state.denoising_start is None else False -+ components.scheduler.set_begin_index(None) - - if block_state.latents is None: - block_state.latents = prepare_latents_img2img( - components.vae, - components.scheduler, - block_state.image_latents, -- block_state.latent_timestep, -+ block_state.timesteps, - block_state.batch_size, - block_state.num_images_per_prompt, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.add_noise, - ) -+ -+ latent_height = block_state.image_latents.shape[-2] -+ latent_width = block_state.image_latents.shape[-1] -+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width) -+ -+ diffdiff_map = diffdiff_map.squeeze(0).to(block_state.device) -+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps -+ thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(block_state.device) -+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0)) -+ block_state.original_latents = block_state.latents - - self.add_block_state(state, block_state) -``` +class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), ++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True})) + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ ++ InputParam("diffdiff_map", required=True), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), +- InputParam("latent_timestep", required=True, type_hint=torch.Tensor), ++ InputParam("timesteps", type_hint=torch.Tensor), ++ InputParam("num_inference_steps", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ ++ OutputParam("original_latents", type_hint=torch.Tensor), ++ OutputParam("diffdiff_masks", type_hint=torch.Tensor), + ] -Now let's modify `before_denoiser` step, we use diff-diff map to freeze certain regions in the latents before each denoising step. + def __call__(self, components, state: PipelineState): + # ... existing logic ... ++ # Process change map and create masks ++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width) ++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps ++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0)) ++ block_state.original_latents = block_state.latents +``` +**2. Modified `before_denoiser` step:** ```diff class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): - model_name = "stable-diffusion-xl" - @property def description(self) -> str: return ( -- "step within the denoising loop that prepare the latent input for the denoiser" -+ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" + "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" ) -+ @property -+ def inputs(self) -> List[Tuple[str, Any]]: -+ return [ -+ InputParam("denoising_start"), -+ ] + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("denoising_start"), + ] @property def intermediate_inputs(self) -> List[str]: return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), -+ InputParam( -+ "original_latents", -+ type_hint=torch.Tensor, -+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -+ ), -+ InputParam( -+ "diffdiff_masks", -+ type_hint=torch.Tensor, -+ description="The masks used for the differential diffusion denoising process, can be generated in prepare_latent step." -+ ), + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("original_latents", type_hint=torch.Tensor), + InputParam("diffdiff_masks", type_hint=torch.Tensor), ] - @torch.no_grad() def __call__(self, components, block_state, i, t): -+ # diff diff -+ if i == 0 and block_state.denoising_start is None: -+ block_state.latents = block_state.original_latents[:1] -+ else: -+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0) -+ # cast mask to the same type as latents etc -+ block_state.mask = block_state.mask.to(block_state.latents.dtype) -+ block_state.mask = block_state.mask.unsqueeze(1) # fit shape -+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) -+ # end diff diff - -+ # expand the latents if we are doing classifier free guidance - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - return components, block_state + # Apply differential diffusion logic + if i == 0 and block_state.denoising_start is None: + block_state.latents = block_state.original_latents[:1] + else: + block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1) + block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) + + # ... rest of existing logic ... ``` That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks. +Now we use the process we've prepred in step2 to build the pipeline and inspect it. - - -💡 You can inspect the pipeline you built with `print()` - - - -```out +```py +>> dd_pipeline SequentialPipelineBlocks( Class: ModularPipelineBlocks @@ -392,7 +339,7 @@ SequentialPipelineBlocks( [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process - [6] denoise (SDXLDiffDiffDenoiseLoop) + [6] denoise (SDXLDiffDiffDenoiseStep) Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes [7] decode (StableDiffusionXLDecodeStep) @@ -401,50 +348,20 @@ SequentialPipelineBlocks( ) ``` -Now if you run the example we prepared earlier, you should see an apple with its right half transformed into a green pear. +Run the example now, you should see an apple with its right half transformed into a green pear. ![Image description](https://cdn-uploads.huggingface.co/production/uploads/624ef9ba9d608e459387b34e/4zqJOz-35Q0i6jyUW3liL.png) ## Adding IP-adapter -We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](TODO) +We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks) -Let's create IP-adapter block: +We talked about how to add IP-adapter into your workflow in the [getting-started guide](https://huggingface.co/docs/diffusers/modular_diffusers/quicktour#ip-adapter). Let's just go ahead to create the IP-adapter block. ```py >>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep >>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() ->>> print(ip_adapter_block) -``` - -It has 4 components: `unet` and `guider` are already used in diff-diff, but it also has two new ones: `image_encoder` and `feature_extractor` - -```out - ip adapter block: StableDiffusionXLAutoIPAdapterStep( - Class: AutoPipelineBlocks - - ==================================================================================================== - This pipeline contains blocks that are selected at runtime based on inputs. - Trigger Inputs: {'ip_adapter_image'} - Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). - ==================================================================================================== - - - Description: Run IP Adapter step if `ip_adapter_image` is provided. - - - Components: - image_encoder (`CLIPVisionModelWithProjection`) - feature_extractor (`CLIPImageProcessor`) - unet (`UNet2DConditionModel`) - guider (`ClassifierFreeGuidance`) - - Blocks: - • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) - Description: IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details - -) ``` We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here). @@ -521,7 +438,7 @@ SequentialPipelineBlocks( [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process - [7] denoise (SDXLDiffDiffDenoiseLoop) + [7] denoise (SDXLDiffDiffDenoiseStep) Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes [8] decode (StableDiffusionXLDecodeStep) @@ -573,14 +490,14 @@ From looking at the code workflow: differential diffusion only modifies the "bef Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs. -With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseLoop`: +With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseStep`: ```py ->>> class SDXLDiffDiffControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): +>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): ... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] ... block_names = ["before_denoiser", "denoiser", "after_denoiser"] >>> ->>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseLoop() +>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep() >>> # print(controlnet_denoise) ``` @@ -588,33 +505,32 @@ We provide a auto controlnet input block that you can directly put into your wor ```py ->>> from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLControlNetAutoInput ->>> control_input_block = StableDiffusionXLControlNetAutoInput() +>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets import StableDiffusionXLAutoControlNetInputStep +>>> control_input_block = StableDiffusionXLAutoControlNetInputStep() >>> print(control_input_block) ``` ```out -StableDiffusionXLControlNetAutoInput( +StableDiffusionXLAutoControlNetInputStep( Class: AutoPipelineBlocks ==================================================================================================== This pipeline contains blocks that are selected at runtime based on inputs. - Trigger Inputs: {'control_image', 'control_mode'} - Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_image')`). + Trigger Inputs: ['control_image', 'control_mode'] ==================================================================================================== Description: Controlnet Input step that prepare the controlnet input. This is an auto pipeline block that works for both controlnet and controlnet_union. - - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. - - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. + (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. + - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped. Components: controlnet (`ControlNetUnionModel`) control_image_processor (`VaeImageProcessor`) - Blocks: + Sub-Blocks: • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep) Description: step that prepares inputs for the ControlNetUnion model @@ -622,6 +538,7 @@ StableDiffusionXLControlNetAutoInput( Description: step that prepare inputs for controlnet ) + ``` Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape. @@ -655,12 +572,12 @@ Let's assemble the blocks and run an example using controlnet + differential dif ... )[0] ``` -Optionally, We can combine `SDXLDiffDiffControlNetDenoiseLoop` and `SDXLDiffDiffDenoiseLoop` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet. +Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet. ```py >>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): -... block_classes = [SDXLDiffDiffControlNetDenoiseLoop, SDXLDiffDiffDenoiseLoop] +... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep] ... block_names = ["controlnet_denoise", "denoise"] ... block_trigger_inputs = ["controlnet_cond", None] ``` @@ -669,7 +586,7 @@ Optionally, We can combine `SDXLDiffDiffControlNetDenoiseLoop` and `SDXLDiffDiff - Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected. We won't go into too much detail about `AutoPipelineBlocks` in this section, but you can read more about it [here](TODO). + Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected. From 4543d216ec8ad1a79dee337d1c3f40d4dea42389 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:06:13 +0200 Subject: [PATCH 135/170] rename quick start- it is really not quick --- .../en/modular_diffusers/{quicktour.md => getting_started.md} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename docs/source/en/modular_diffusers/{quicktour.md => getting_started.md} (99%) diff --git a/docs/source/en/modular_diffusers/quicktour.md b/docs/source/en/modular_diffusers/getting_started.md similarity index 99% rename from docs/source/en/modular_diffusers/quicktour.md rename to docs/source/en/modular_diffusers/getting_started.md index 211fe738a1dc..3601b9069d62 100644 --- a/docs/source/en/modular_diffusers/quicktour.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Getting Started with Modular Diffusers +# Getting Started with Modular Diffusers: An Comprehensive Overview With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers let you: From 1987c0789967e82a5431fd46ecd7c4bc53b9ee89 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:06:34 +0200 Subject: [PATCH 136/170] update docstree --- docs/source/en/_toctree.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d3aa0be3331f..bc6420f665ae 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -92,12 +92,12 @@ title: API Reference title: Hybrid Inference - sections: - - local: modular_diffusers/quicktour - title: Quicktour + - local: modular_diffusers/getting_started + title: Getting Started - local: modular_diffusers/write_own_pipeline_block title: Write your own pipeline block - - local: modular_diffusers/developer_guide - title: Developer Guide + - local: modular_diffusers/end_to_end_guide + title: End-to-End Developer Guide title: Modular Diffusers - sections: - local: using-diffusers/consisid From 2e2024152c93aa920eff872414e1f246fb8c23b4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:07:08 +0200 Subject: [PATCH 137/170] up up --- docs/source/en/modular_diffusers/write_own_pipeline_block.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index a935380712ff..ac01b374cd44 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# `ModularPipelineBlocks` +# Writing Your Own Pipeline Blocks In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. @@ -132,7 +132,7 @@ expected_config = [ ] ``` -**Components**: You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). +**Components**: In the `ComponentSpec`, You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). **Configs**: Simple pipeline-level settings that control behavior across all blocks. From 13fe24815225233c636670ae81a3af75a7c5d8ae Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:22:15 +0200 Subject: [PATCH 138/170] add modularpipelineblocks to be pushtohub mixin --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 1ada2fc41473..5cc27d5586b8 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -246,7 +246,7 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" -class ModularPipelineBlocks(ConfigMixin): +class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks From 8cb5b084b591777e1f42afe98005ac706dc39062 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 03:22:27 +0200 Subject: [PATCH 139/170] up upup --- docs/source/en/modular_diffusers/end_to_end_guide.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index 0784ce2e1fbd..ab4ba8020da3 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -613,9 +613,13 @@ to use You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff -[YiYi TODO: add details tutorial on how to create the modular repo, building upon this https://github.com/huggingface/diffusers/pull/11462] +To create a Modular Repo and share on hub, you just need to run. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily. -With a modular repo, it is very easy for the community to use the workflow you just created! +```py +dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True) +``` + +With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example to use the differential-diffusion pipeline we just created and shared. ```py >>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager From 3e46c86a9357259fb9cc1ca709be323e86f06147 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 1 Jul 2025 04:51:49 +0200 Subject: [PATCH 140/170] fix links in the doc --- .../en/modular_diffusers/getting_started.md | 16 +++++++--------- .../write_own_pipeline_block.md | 4 ++-- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 3601b9069d62..4a527b69a165 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -18,7 +18,7 @@ With Modular Diffusers, we introduce a unified pipeline system that simplifies h **Assemble Like LEGO®**: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that that can be used more conveniently for multiple workflows. -In this guide, we will focus on how to build pipelines this way using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](TODO). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](developer_guide.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. +In this guide, we will focus on how to build pipelines this way using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](./write_own_pipeline_block.md). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](./end_to_end.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. Let's get started! The Modular Diffusers Framework consists of three main components: - ModularPipelineBlocks @@ -29,10 +29,10 @@ Let's get started! The Modular Diffusers Framework consists of three main compon Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: -- [`PipelineBlock`](TODO): The most granular block - you define the computation logic. -- [`SequentialPipelineBlocks`](TODO): A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. -- [`LoopSequentialPipelineBlocks`](TODO): A special type of multi-block that forms loops. -- [`AutoPipelineBlocks`](TODO): A multi-block composed of multiple blocks that are selected at runtime based on the inputs. +- [`PipelineBlock`]: The most granular block - you define the computation logic. +- [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. +- [`LoopSequentialPipelineBlocks`]: A special type of multi-block that forms loops. +- [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be used standalone or combined into larger blocks. Blocks are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. @@ -288,7 +288,7 @@ ALL_BLOCKS = { -We will not go over how to write your own ModularPipelineBlocks but you can learn more about it [here](TODO). +We will not go over how to write your own ModularPipelineBlocks but you can learn more about it [here](./write_own_pipeline_block.md). This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. @@ -848,7 +848,7 @@ StableDiffusionXLAutoControlnetStep( -💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](TODO) for more details. +💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks) for more details. @@ -1029,8 +1029,6 @@ Since we have a modular setup where different pipelines may share components, we - `blocks.init_pipeline(repo)` creates a pipeline with a built-in loader that only includes components its blocks needs - `StableDiffusionXLModularLoader.from_pretrained(repo)` set up a standalone loader that includes everything in the repo's `modular_model_index.json` -See the [Loader Guide](TODO) for more details. - ```py diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index ac01b374cd44..4739bbc6901a 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -132,7 +132,7 @@ expected_config = [ ] ``` -**Components**: In the `ComponentSpec`, You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [quicktour](quicktour.md#loading-components-into-a-modularpipeline). +**Components**: In the `ComponentSpec`, You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Getting Started Guide](https://huggingface.co/docs/diffusers/en/modular_diffusers/getting_started#loading-components-into-a-modularpipeline). **Configs**: Simple pipeline-level settings that control behavior across all blocks. @@ -292,7 +292,7 @@ I hope by now you have a basic idea about how `PipelineBlock` manages state thro ## Create a `SequentialPipelineBlocks` -I assume that you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [Getting Started Guide](https://moon-ci-docs.huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/quicktour#modularpipelineblocks). +I assume that you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [Getting Started Guide](https://huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/getting_started#modularpipelineblocks). But how do blocks actually connect and work together? Understanding this is crucial for building effective modular workflows. Let's explore this through an example. From 13c51bb038768cbf7b39b776410c5ca2cae5f554 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Jul 2025 03:49:10 +0530 Subject: [PATCH 141/170] Modular PAG Guider (#11860) * update * fix * update --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 2 + .../guiders/perturbed_attention_guidance.py | 114 ++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 src/diffusers/guiders/perturbed_attention_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2c427847948f..314a4126d2bd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -138,6 +138,7 @@ "AutoGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", + "PerturbedAttentionGuidance", "SkipLayerGuidance", "SmoothedEnergyGuidance", "TangentialClassifierFreeGuidance", @@ -785,6 +786,7 @@ AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 37e0fa400360..1c288f00f084 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -22,6 +22,7 @@ from .auto_guidance import AutoGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance @@ -31,6 +32,7 @@ AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance, diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py new file mode 100644 index 000000000000..cd75348ef5bd --- /dev/null +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from ..hooks import LayerSkipConfig +from .skip_layer_guidance import SkipLayerGuidance + + +class PerturbedAttentionGuidance(SkipLayerGuidance): + """ + Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 + + The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea + of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the + attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen + layers. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters + and implementation details. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + perturbed_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for perturbed attention guidance. + perturbed_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which perturbed attention guidance starts. + perturbed_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which perturbed attention guidance stops. + perturbed_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers. + If not provided, `skip_layer_config` must be provided. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in + # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very + # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation + # for each model architecture. + + def __init__( + self, + guidance_scale: float = 7.5, + perturbed_guidance_scale: float = 2.8, + perturbed_guidance_start: float = 0.01, + perturbed_guidance_stop: float = 0.2, + perturbed_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + if skip_layer_config is None: + if perturbed_guidance_layers is None: + raise ValueError( + "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified." + ) + skip_layer_config = LayerSkipConfig( + indices=perturbed_guidance_layers, + skip_attention=False, + skip_attention_scores=True, + skip_ff=False, + ) + else: + if perturbed_guidance_layers is not None: + raise ValueError( + "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified." + ) + + super().__init__( + guidance_scale=guidance_scale, + skip_layer_guidance_scale=perturbed_guidance_scale, + skip_layer_guidance_start=perturbed_guidance_start, + skip_layer_guidance_stop=perturbed_guidance_stop, + skip_layer_guidance_layers=perturbed_guidance_layers, + skip_layer_config=skip_layer_config, + guidance_rescale=guidance_rescale, + use_original_formulation=use_original_formulation, + start=start, + stop=stop, + ) From b750c69859e6f9c0ea85c82fbaf6534b1f3a84f9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Jul 2025 08:38:05 +0530 Subject: [PATCH 142/170] Modular Guider ConfigMixin (#11862) * update * update * register to config pag --- src/diffusers/guiders/adaptive_projected_guidance.py | 2 ++ src/diffusers/guiders/auto_guidance.py | 2 ++ src/diffusers/guiders/classifier_free_guidance.py | 2 ++ .../guiders/classifier_free_zero_star_guidance.py | 2 ++ src/diffusers/guiders/entropy_rectifying_guidance.py | 0 src/diffusers/guiders/guider_utils.py | 7 ++++++- src/diffusers/guiders/perturbed_attention_guidance.py | 2 ++ src/diffusers/guiders/skip_layer_guidance.py | 2 ++ src/diffusers/guiders/smoothed_energy_guidance.py | 2 ++ .../guiders/tangential_classifier_free_guidance.py | 2 ++ src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 11 files changed, 23 insertions(+), 2 deletions(-) delete mode 100644 src/diffusers/guiders/entropy_rectifying_guidance.py diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 10d05258bc3f..81137db106a0 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -53,6 +54,7 @@ class AdaptiveProjectedGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index dc1bf26ade39..159354559966 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -60,6 +61,7 @@ class AutoGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index cc3ffd5758b8..7e72b92fcee2 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -67,6 +68,7 @@ class ClassifierFreeGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index ea4c4c197f7a..85d5cc62d4e7 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -58,6 +59,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/entropy_rectifying_guidance.py b/src/diffusers/guiders/entropy_rectifying_guidance.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 1c7d6de796b5..555f8897c089 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -16,6 +16,7 @@ import torch +from ..configuration_utils import ConfigMixin from ..utils import get_logger @@ -23,12 +24,16 @@ from ..modular_pipelines.modular_pipeline import BlockState +GUIDER_CONFIG_NAME = "guider_config.json" + + logger = get_logger(__name__) # pylint: disable=invalid-name -class BaseGuidance: +class BaseGuidance(ConfigMixin): r"""Base class providing the skeleton for implementing guidance techniques.""" + config_name = GUIDER_CONFIG_NAME _input_predictions = None _identifier_key = "__guidance_identifier__" diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index cd75348ef5bd..dbba904d0bde 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -14,6 +14,7 @@ from typing import List, Optional, Union +from ..configuration_utils import register_to_config from ..hooks import LayerSkipConfig from .skip_layer_guidance import SkipLayerGuidance @@ -70,6 +71,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation # for each model architecture. + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 4b8e556fbe96..e67b20df19fa 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -86,6 +87,7 @@ class SkipLayerGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 5b3e2d6c6390..66c46064d46d 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from ..hooks import HookRegistry from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -76,6 +77,7 @@ class SmoothedEnergyGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 28f55880aab5..b3187e526316 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -17,6 +17,7 @@ import torch +from ..configuration_utils import register_to_config from .guider_utils import BaseGuidance, rescale_noise_cfg @@ -49,6 +50,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance): _input_predictions = ["pred_cond", "pred_uncond"] + @register_to_config def __init__( self, guidance_scale: float = 7.5, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5cc27d5586b8..99db80d3151c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1911,7 +1911,7 @@ def update(self, **kwargs): loader.update(unet=new_unet_model, text_encoder=new_text_encoder) # Update configuration values - loader.update(requires_safety_checker=False, guidance_rescale=0.7) + loader.update(requires_safety_checker=False) # Update both components and configs together loader.update(unet=new_unet_model, requires_safety_checker=False) From 284f827d6c0badb74626ee5a573e039d6cf32ac8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Jul 2025 23:19:35 +0530 Subject: [PATCH 143/170] Modular custom config object serialization (#11868) * update * make style --- src/diffusers/configuration_utils.py | 4 ++ src/diffusers/guiders/auto_guidance.py | 9 +++- .../guiders/perturbed_attention_guidance.py | 46 +++++++++++++++---- src/diffusers/guiders/skip_layer_guidance.py | 9 +++- .../guiders/smoothed_energy_guidance.py | 5 ++ src/diffusers/hooks/layer_skip.py | 9 +++- .../hooks/smoothed_energy_guidance_utils.py | 9 +++- 7 files changed, 76 insertions(+), 15 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f9b652bbc021..770c949ffb3d 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -601,6 +601,10 @@ def to_json_saveable(value): value = value.tolist() elif isinstance(value, Path): value = value.as_posix() + elif hasattr(value, "to_dict") and callable(value.to_dict): + value = value.to_dict() + elif isinstance(value, list): + value = [to_json_saveable(v) for v in value] return value if "quantization_config" in config_dict: diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 159354559966..e1642211d393 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch @@ -66,7 +66,7 @@ def __init__( self, guidance_scale: float = 7.5, auto_guidance_layers: Optional[Union[int, List[int]]] = None, - auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, dropout: Optional[float] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, @@ -104,6 +104,9 @@ def __init__( LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers ] + if isinstance(auto_guidance_config, dict): + auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config) + if isinstance(auto_guidance_config, LayerSkipConfig): auto_guidance_config = [auto_guidance_config] @@ -111,6 +114,8 @@ def __init__( raise ValueError( f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." ) + elif isinstance(next(iter(auto_guidance_config), None), dict): + auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config] self.auto_guidance_config = auto_guidance_config self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index dbba904d0bde..3045f2feaae2 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from ..configuration_utils import register_to_config from ..hooks import LayerSkipConfig +from ..utils import get_logger from .skip_layer_guidance import SkipLayerGuidance +logger = get_logger(__name__) # pylint: disable=invalid-name + + class PerturbedAttentionGuidance(SkipLayerGuidance): """ Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 @@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): The fraction of the total number of denoising steps after which perturbed attention guidance stops. perturbed_guidance_layers (`int` or `List[int]`, *optional*): The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers. - If not provided, `skip_layer_config` must be provided. - skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + If not provided, `perturbed_guidance_config` must be provided. + perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided. guidance_rescale (`float`, defaults to `0.0`): @@ -79,19 +83,20 @@ def __init__( perturbed_guidance_start: float = 0.01, perturbed_guidance_stop: float = 0.2, perturbed_guidance_layers: Optional[Union[int, List[int]]] = None, - skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0, ): - if skip_layer_config is None: + if perturbed_guidance_config is None: if perturbed_guidance_layers is None: raise ValueError( - "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified." + "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified." ) - skip_layer_config = LayerSkipConfig( + perturbed_guidance_config = LayerSkipConfig( indices=perturbed_guidance_layers, + fqn="auto", skip_attention=False, skip_attention_scores=True, skip_ff=False, @@ -99,8 +104,31 @@ def __init__( else: if perturbed_guidance_layers is not None: raise ValueError( - "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified." + "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified." + ) + + if isinstance(perturbed_guidance_config, dict): + perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config) + + if isinstance(perturbed_guidance_config, LayerSkipConfig): + perturbed_guidance_config = [perturbed_guidance_config] + + if not isinstance(perturbed_guidance_config, list): + raise ValueError( + "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`." + ) + elif isinstance(next(iter(perturbed_guidance_config), None), dict): + perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config] + + for config in perturbed_guidance_config: + if config.skip_attention or not config.skip_attention_scores or config.skip_ff: + logger.warning( + "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. " + "Please check your configuration. Modifying the config to match the expected values." ) + config.skip_attention = False + config.skip_attention_scores = True + config.skip_ff = False super().__init__( guidance_scale=guidance_scale, @@ -108,7 +136,7 @@ def __init__( skip_layer_guidance_start=perturbed_guidance_start, skip_layer_guidance_stop=perturbed_guidance_stop, skip_layer_guidance_layers=perturbed_guidance_layers, - skip_layer_config=skip_layer_config, + skip_layer_config=perturbed_guidance_config, guidance_rescale=guidance_rescale, use_original_formulation=use_original_formulation, start=start, diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index e67b20df19fa..68a657960a45 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch @@ -95,7 +95,7 @@ def __init__( skip_layer_guidance_start: float = 0.01, skip_layer_guidance_stop: float = 0.2, skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, - skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, @@ -135,6 +135,9 @@ def __init__( ) skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + if isinstance(skip_layer_config, dict): + skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config) + if isinstance(skip_layer_config, LayerSkipConfig): skip_layer_config = [skip_layer_config] @@ -142,6 +145,8 @@ def __init__( raise ValueError( f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." ) + elif isinstance(next(iter(skip_layer_config), None), dict): + skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config] self.skip_layer_config = skip_layer_config self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 66c46064d46d..d8e8a3cf2fa8 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -125,6 +125,9 @@ def __init__( ) seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + if isinstance(seg_guidance_config, dict): + seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config) + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): seg_guidance_config = [seg_guidance_config] @@ -132,6 +135,8 @@ def __init__( raise ValueError( f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." ) + elif isinstance(next(iter(seg_guidance_config), None), dict): + seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config] self.seg_guidance_config = seg_guidance_config self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 32c9f205d683..487a1876d605 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Callable, List, Optional import torch @@ -78,6 +78,13 @@ def __post_init__(self): "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." ) + def to_dict(self): + return asdict(self) + + @staticmethod + def from_dict(data: dict) -> "LayerSkipConfig": + return LayerSkipConfig(**data) + class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index 65cce3c53907..622f60764762 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import List, Optional import torch @@ -51,6 +51,13 @@ class SmoothedEnergyGuidanceConfig: fqn: str = "auto" _query_proj_identifiers: List[str] = None + def to_dict(self): + return asdict(self) + + @staticmethod + def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig": + return SmoothedEnergyGuidanceConfig(**data) + class SmoothedEnergyGuidanceHook(ModelHook): def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: From 2c66fb3a85582f66f68fc2832e1c45b051d9f3b0 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sat, 5 Jul 2025 14:26:13 -1000 Subject: [PATCH 144/170] Apply suggestions from code review Co-authored-by: Sayak Paul --- .../en/modular_diffusers/getting_started.md | 26 +++++++++---------- src/diffusers/commands/custom_blocks.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 4a527b69a165..c74223036429 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -10,15 +10,15 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Getting Started with Modular Diffusers: An Comprehensive Overview +# Getting Started with Modular Diffusers: A Comprehensive Overview -With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers let you: +With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you: -**Write Only What's New**: You won't need to rewrite the entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. +**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. -**Assemble Like LEGO®**: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that that can be used more conveniently for multiple workflows. +**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows. -In this guide, we will focus on how to build pipelines this way using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](./write_own_pipeline_block.md). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](./end_to_end.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. +In this guide, we will focus on how to build end-to-end pipelines using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](./write_own_pipeline_block.md). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](./end_to_end.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. Let's get started! The Modular Diffusers Framework consists of three main components: - ModularPipelineBlocks @@ -40,12 +40,13 @@ It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Di ```py from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep + text_encoder_block = StableDiffusionXLTextEncoderStep() ``` This is a single `PipelineBlock`. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings outputs such as `prompt_embeds` and `negative_prompt_embeds`. -``` +```py >>> text_encoder_block StableDiffusionXLTextEncoderStep( Class: PipelineBlock @@ -211,7 +212,7 @@ You can extract a block instance from the multi-block to use it independently. A >>> text_encoder_blocks ``` -the multi-block now has fewer components and no longer has the `text_encoder` block. If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead. +The multi-block now has fewer components and no longer has the `text_encoder` block. If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead. ```py >>> t2i_blocks @@ -294,7 +295,7 @@ This covers the essentials of pipeline blocks! You may have noticed that we have ## PipelineState & BlockState -`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates PipelineState with any changes. +`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` as needed. @@ -310,11 +311,10 @@ You typically don't need to manually create or manage these state objects. The ` `ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). -Instead of using a `model_index.json` to configure components loading in `DiffusionPipeline`. Modular repositories work with `modular_model_index.json`. Let's walk through the difference here. +A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here. In standard `model_index.json`, each component entry is a `(library, class)` tuple: -```py "text_encoder": [ "transformers", "CLIPTextModel" @@ -428,7 +428,7 @@ All expected components are now loaded into the pipeline. You can also partially >>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16) ``` -You can inspect the pipeline's loading status through its `loader` attribute to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. +You can inspect the `loader` attribute of a pipeline to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. For example, if your pipeline needs a `text_encoder` component, the loader will include the loading spec for `text_encoder` from the modular repo. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't appear in the loader even if it exists in the modular repo. @@ -594,7 +594,7 @@ There are also a few properties that can provide a quick summary of component lo ### Modifying Loading Specs -When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. The pipeline's `loader` attribute is synced with these specs - it shows you exactly what will be loaded and from where. +When you call `pipeline.load_components(names=...)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. The pipeline's `loader` attribute is synced with these specs - it shows you exactly what will be loaded and from where. You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. @@ -714,7 +714,7 @@ t2i_pipeline.doc #### Text-to-Image, Image-to-Image, and Inpainting -These are minimum inference example for our basic tasks: text-to-image, image-to-image and inpainting. The process to create different pipelines is the same - only difference is the block classes presets. The inference is also more or less same to standard pipelines, but please always check `.doc` for correct input names and remember to pass `output="images"`. +These are minimum inference examples for basic tasks: text-to-image, image-to-image, and inpainting. The process to create different pipelines is the same - only difference is the block classes presets. The inference is also more or less same to standard pipelines, but please always check `.doc` for correct input names and remember to pass `output="images"`. diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index 07fca44678ba..43d9ea88577a 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -27,7 +27,7 @@ from . import BaseDiffusersCLICommand -EXPECTED_PARENT_CLASSES = ["PipelineBlock"] +EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"] CONFIG = "config.json" From 63e94cbc611bd97a04767c91a207e465ea14f948 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 6 Jul 2025 02:56:00 +0200 Subject: [PATCH 145/170] resolve conflicnt --- .../en/modular_diffusers/end_to_end_guide.md | 6 +- .../en/modular_diffusers/getting_started.md | 218 +++---- .../write_own_pipeline_block.md | 12 +- src/diffusers/__init__.py | 6 +- src/diffusers/modular_pipelines/__init__.py | 6 +- .../modular_pipelines/components_manager.py | 29 +- .../modular_pipelines/modular_pipeline.py | 616 +++++++++--------- .../modular_pipeline_utils.py | 2 +- .../stable_diffusion_xl/__init__.py | 8 +- .../stable_diffusion_xl/before_denoise.py | 42 +- .../stable_diffusion_xl/decoders.py | 4 +- .../stable_diffusion_xl/denoise.py | 18 +- .../stable_diffusion_xl/encoders.py | 21 +- ...ar_blocks_presets.py => modular_blocks.py} | 0 ...{modular_loader.py => modular_pipeline.py} | 8 +- .../pipelines/pipeline_loading_utils.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 15 - .../dummy_torch_and_transformers_objects.py | 2 +- 18 files changed, 504 insertions(+), 511 deletions(-) rename src/diffusers/modular_pipelines/stable_diffusion_xl/{modular_blocks_presets.py => modular_blocks.py} (100%) rename src/diffusers/modular_pipelines/stable_diffusion_xl/{modular_loader.py => modular_pipeline.py} (99%) diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index ab4ba8020da3..132c4870b770 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -505,7 +505,7 @@ We provide a auto controlnet input block that you can directly put into your wor ```py ->>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets import StableDiffusionXLAutoControlNetInputStep +>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep >>> control_input_block = StableDiffusionXLAutoControlNetInputStep() >>> print(control_input_block) ``` @@ -613,7 +613,7 @@ to use You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff -To create a Modular Repo and share on hub, you just need to run. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily. +To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily. ```py dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True) @@ -626,7 +626,7 @@ With a modular repo, it is very easy for the community to use the workflow you j >>> import torch >>> from diffusers.utils import load_image >>> ->>> repo_id = "YiYiXu/modular-diffdiff" +>>> repo_id = "YiYiXu/modular-diffdiff-0704" >>> >>> components = ComponentsManager() >>> diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index c74223036429..ff1633988e5f 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -31,10 +31,12 @@ Pipeline blocks are the fundamental building blocks of the Modular Diffusers sys - [`PipelineBlock`]: The most granular block - you define the computation logic. - [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. -- [`LoopSequentialPipelineBlocks`]: A special type of multi-block that forms loops. +- [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models. - [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. -All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be used standalone or combined into larger blocks. Blocks are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. +All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be defined standalone or combined into larger blocks - They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. However, blocks aren't runnable on thier own and they need to be converted into a a ModularPipeline to actually run. + +**Blocks vs Pipelines**: Blocks are just definitions - they define what components, inputs/outputs, and computation logics are needed, but they don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. See the [ModularPipeline from ModularPipelineBlocks](#modularpipeline-from-modularpipelineblocks) section for how to create and run pipelines. It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers @@ -321,10 +323,10 @@ In standard `model_index.json`, each component entry is a `(library, class)` tup ], ``` -In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` +In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs_dict)` - `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded) -- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. +- `loading_specs_dict`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. ```py "text_encoder": [ @@ -342,21 +344,8 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra } ], ``` -Some components may not have `repo` field, they cannot be loaded from a repository and can only be created with default config from the pipeline -```py - "image_processor": [ - "diffusers", - "VaeImageProcessor", - { - "type_hint": [ - "diffusers", - "VaeImageProcessor" - ] - } - ], -``` -Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. ### Creating a `ModularPipeline` from `ModularPipelineBlocks` @@ -370,7 +359,7 @@ Let's convert our `t2i_blocks` (which we created earlier) into a runnable `Modul t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) # Now convert it to a ModularPipeline -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) ``` @@ -398,22 +387,36 @@ You can read more about Components Manager [here](TODO) You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: ```py -# YiYi TODO: this is not yet supported actually 😢, need to add support from diffusers import ModularPipeline -pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) +pipeline = ModularPipeline.from_pretrained( "YiYiXu/modular-loader-t2i-0704") ``` Loading custom code is also supported: ```py from diffusers import ModularPipeline -modular_repo_id = "YiYiXu/modular-diffdiff" +modular_repo_id = "YiYiXu/modular-diffdiff-0704" diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True) ``` +This modular repository contains custom code. The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation: + +```json +{ + "_class_name": "DiffDiffBlocks", + "auto_map": { + "ModularPipelineBlocks": "block.DiffDiffBlocks" + } +} +``` + +The `auto_map` tells the pipeline where to find the custom blocks definition - in this case, it's looking for `DiffDiffBlocks` in the `block.py` file. The actual `DiffDiffBlocks` class is defined in [`block.py`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/block.py) within the repository. + +When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library. + ### Loading components into a `ModularPipeline` -Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_components`: +Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=..,)`: ```py # This will load ALL the expected components into pipeline @@ -428,49 +431,15 @@ All expected components are now loaded into the pipeline. You can also partially >>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16) ``` -You can inspect the `loader` attribute of a pipeline to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. - -For example, if your pipeline needs a `text_encoder` component, the loader will include the loading spec for `text_encoder` from the modular repo. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't appear in the loader even if it exists in the modular repo. - -The loader has the same structure as `modular_model_index.json` - each component entry contains the `(library, class, loading_specs)` format. You'll need to understand that structure to properly read the loading status below. - - - -💡 **How to read the loader**: -- **`library` and `class` fields**: Show info about actually loaded components. If `null`, the component is not loaded yet. -- **`loading_specs`**: If it does not have `repo` field or if it is `null`, the component cannot be loaded from a repository and can only be created with default config by the pipeline. - - - -Let's inspect the `t2i_pipeline.loader`, you can see all the components expected to load are listed as entries in the loader. The `guider` and `image_processor` components were created using default config (their `library` and `class` field are populated, this means they are initialized, but their loading spec dict is missing loading related fields). The `vae` and `unet` components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their `library`, `class` fields are `null`), but you can examine their loading specs to see where they would be loaded from when you call `load_components()`. - +You can inspect the pipeline's loading status by simply printing the pipeline itself. It helps you understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. Let's print out the `t2i_pipeline`: ```py ->>> t2i_pipeline.loader -StableDiffusionXLModularLoader { - "_class_name": "StableDiffusionXLModularLoader", - "_diffusers_version": "0.34.0.dev0", +>>> t2i_pipeline +StableDiffusionXLModularPipeline { + "_blocks_class_name": "SequentialPipelineBlocks", + "_class_name": "StableDiffusionXLModularPipeline", + "_diffusers_version": "0.35.0.dev0", "force_zeros_for_empty_prompt": true, - "guider": [ - "diffusers", - "ClassifierFreeGuidance", - { - "type_hint": [ - "diffusers", - "ClassifierFreeGuidance" - ] - } - ], - "image_processor": [ - "diffusers", - "VaeImageProcessor", - { - "type_hint": [ - "diffusers", - "VaeImageProcessor" - ] - } - ], "scheduler": [ null, null, @@ -572,31 +541,42 @@ StableDiffusionXLModularLoader { } ``` +You can see all the components that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: + +- **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet. +- **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.) + +In this example: +- **Loaded components**: `vae` and `unet` (their `library` and `class` fields show the actual loaded models) +- **Not loaded yet**: `scheduler`, `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2` (their `library` and `class` fields are `null`, but you can see their loading specs to know where they'll be loaded from when you call `load_components()`) + +You're looking at essentailly the pipeline's config dict that's synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. + +For example, if your pipeline needs a `text_encoder` component, it will include the loading spec for `text_encoder` from the modular repo during the `init_pipeline`. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't be included even if it exists in the modular repo. + There are also a few properties that can provide a quick summary of component loading status: ```py # All components expected by the pipeline ->>> t2i_pipeline.loader.component_names +>>> t2i_pipeline.component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor'] # Components that are not loaded yet (will be loaded with from_pretrained) ->>> t2i_pipeline.loader.null_component_names +>>> t2i_pipeline.null_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler'] # Components that will be loaded from pretrained models ->>> t2i_pipeline.loader.pretrained_component_names +>>> t2i_pipeline.pretrained_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae'] # Components that are created with default config (no repo needed) ->>> t2i_pipeline.loader.config_component_names +>>> t2i_pipeline.config_component_names ['guider', 'image_processor'] ``` ### Modifying Loading Specs -When you call `pipeline.load_components(names=...)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. The pipeline's `loader` attribute is synced with these specs - it shows you exactly what will be loaded and from where. - -You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. ```py # Original spec in modular_model_index.json @@ -682,6 +662,31 @@ StableDiffusionXLModularLoader { ... } ``` + + +💡 **Modifying Component Specs**: You can get a copy of the current component spec from the pipeline using `get_component_spec()`. This makes it easy to modify the spec and updating components. + +```py +>>> unet_spec = t2i_pipeline.get_component_spec("unet") +>>> unet_spec +ComponentSpec( + name='unet', + type_hint=, + repo='RunDiffusion/Juggernaut-XL-v9', + subfolder='unet', + variant='fp16', + default_creation_method='from_pretrained' +) + +# Modify the spec to load from a different repository +>>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0" + +# Load the component with the modified spec +>>> unet = unet_spec.load() +``` + + + ### Running a `ModularPipeline` @@ -728,7 +733,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS # create pipeline from official blocks preset blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -750,7 +755,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS # create pipeline from blocks preset blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -775,7 +780,7 @@ from diffusers.utils import load_image # create pipeline from blocks preset blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -809,7 +814,7 @@ For ControlNet, we provide one auto block you can place at the `denoise` step. L >>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS >>> ALL_BLOCKS["controlnet"] InsertableDict([ - 0: ('denoise', ) + 0: ('denoise', ) ]) >>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]() >>> controlnet_blocks @@ -899,7 +904,7 @@ Let's walk through the steps: >>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS >>> ALL_BLOCKS["ip_adapter"] InsertableDict([ - 0: ('ip_adapter', ) + 0: ('ip_adapter', ) ]) ``` @@ -932,8 +937,7 @@ StableDiffusionXLAutoIPAdapterStep( Sub-Blocks: • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) Description: IP Adapter step that prepares ip adapter image embeddings. - Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.loader. - e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale(). + Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale(). See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details ) @@ -958,12 +962,12 @@ modular_repo_id = "YiYiXu/modular-demo-auto" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) -pipeline.loader.load_ip_adapter( +pipeline.load_ip_adapter( "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin" ) -pipeline.loader.set_ip_adapter_scale(0.8) +pipeline.set_ip_adapter_scale(0.8) pipeline.to("cuda") ``` @@ -1020,31 +1024,23 @@ components = ComponentsManager() components.enable_auto_cpu_offload(device="cuda") ``` -Since we have a modular setup where different pipelines may share components, we recommend using a standalone loader to load components all at once and add them to each pipeline with `update_components()`. - - - +Since we have a modular setup where different pipelines may share components, we recommend using a seperate `ModularPipeline` to load components all at once and add them to each pipeline with `update_components()`. -💡 **Load components without pipeline blocks**: -- `blocks.init_pipeline(repo)` creates a pipeline with a built-in loader that only includes components its blocks needs -- `StableDiffusionXLModularLoader.from_pretrained(repo)` set up a standalone loader that includes everything in the repo's `modular_model_index.json` - - ```py -from diffusers import StableDiffusionXLModularLoader +from diffusers import ModularPipeline t2i_repo = "YiYiXu/modular-demo-auto" -t2i_loader = StableDiffusionXLModularLoader.from_pretrained(t2i_repo, components_manager=components, collection="t2i") +t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) ``` -We'll load components in `t2i_loader`. You can get the list of all loadable components from loader's `pretrained_component_names` property. +We'll load components in `t2i_loader_pipe`. You can get the list of all loadable components from loader's `pretrained_component_names` property. ```py ->>> t2i_loader.pretrained_component_names +>>> t2i_loader_pipe.pretrained_component_names ['controlnet', 'image_encoder', 'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'] ``` @@ -1054,7 +1050,7 @@ It include controlnet and image_encoder for ip-adapter that we don't need now. B import torch # inspect before you load # t2i_loader -t2i_loader.load(t2i_loader.pretrained_component_names, torch_dtype=torch.float16) +t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16) ``` All the models are registered to components manager under the collection "t2i". @@ -1088,15 +1084,15 @@ Additional Component Info: ``` Let's add the loaded components to each pipeline. We'll follow this pattern for each pipeline: -1. Check what components the pipeline needs: inspect `pipeline.loader` or use `loader.null_component_names` +1. Check what components the pipeline needs: inspect `pipeline` or use `pipeline.null_component_names` 2. Get them from the components manager: use its `search_models()`/`get_one`/`get_components_from_names` method 3. Update the pipeline: `pipeline.update_components()` -4. Verify the components are loaded correctly: inspect `pipeline.loader` as well as components manager +4. Verify the components are loaded correctly: inspect `pipeline` as well as components manager We will start with `decoder_node`. First, check what components it needs: ```py ->>> decoder_node.loader.null_component_names +>>> decoder_node.null_component_names ['vae'] ``` The pipeline only needs a `vae`. Looking at the components manager table, there's only one VAE available: @@ -1116,24 +1112,24 @@ decoder_node.update_components(vae=vae) Verify it's correctly loaded: ```py -decoder_node.loader +decoder_node ``` Now let's do the same for `text_node`. Get the list of components the pipeline needs to load: ```py ->>> text_node.loader.null_component_names +>>> text_node.null_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2'] ``` Pass the list directly to the components manager to get the components and add it to the pipeline ```py -text_components = components.get_components_by_names(text_node.loader.null_component_names) +text_components = components.get_components_by_names(text_node.null_component_names) # Add components to pipeline text_node.update_components(**text_components) # Verify components are loaded -assert not text_node.loader.null_component_names -text_node.loader +assert not text_node.null_component_names +text_node ``` Finally, let's set up `t2i_pipe`: @@ -1141,12 +1137,12 @@ Finally, let's set up `t2i_pipe`: ```py # Get unet & scheduler from components manager and add to pipeline -comps = components.get_components_by_names(t2i_pipe.loader.null_component_names) +comps = components.get_components_by_names(t2i_pipe.null_component_names) t2i_pipe.update_components(**comps) # Verify everything is loaded -assert not t2i_pipe.loader.null_component_names -t2i_pipe.loader +assert not t2i_pipe.null_component_names +t2i_pipe # Verify components manager hasn't changed (we only reused existing components) components @@ -1183,7 +1179,7 @@ image.save("modular_part2_t2i.png") Now let's add a LoRA to our pipeline. With the modular approach we will be able to reuse intermediate outputs from blocks that otherwise needs to be re-run. Let's load the LoRA weights and see what happens: ```py -t2i_loader.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") +t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") components ``` Notice that the "Additional Component Info" section shows that only the `unet` component has the LoRA adapter loaded. This means we can skip the text encoding step and reuse the existing embeddings, making the generation much faster. @@ -1231,12 +1227,12 @@ ipa_node = ipa_blocks.init_pipeline(t2i_repo, components_manager=components) comps = components.get_components_by_names(ipa_node.loader.null_component_names) ipa_node.update_components(**comps) -t2i_loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") -t2i_loader.set_ip_adapter_scale(0.6) +t2i_loader_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") +t2i_loader_pipe.set_ip_adapter_scale(0.6) # check it's correctly loaded -assert not ipa_node.loader.null_component_names -ipa_node.loader +assert not ipa_node.null_component_names +ipa_node # find out inputs/outputs print(ipa_node.doc) @@ -1305,7 +1301,7 @@ refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=com We want to reuse components from the t2i pipeline in the refiner as much as possible. First, let's check the loading status of the refiner pipeline to understand what components are needed: ```py ->>> refiner_pipe.loader +>>> refiner_pipe ``` Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is different from the one we loaded for text-to-image. The remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler` are already available in the t2i collection, we can reuse them instead of loading duplicates. @@ -1314,7 +1310,7 @@ Looking at the loader output, you can see that `text_encoder` and `tokenizer` ha refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) # verify loaded correctly -refiner_pipe.loader +refiner_pipe # veryfiy registered to components manager under refiner components diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 4739bbc6901a..f65af4463ff9 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -107,7 +107,7 @@ def __call__(self, components, state): # You can access them like: block_state.image, block_state.processed_image # Update the pipeline state with your updated block_states - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -140,7 +140,7 @@ When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks -Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.add_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. +Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. **Helper Function** @@ -172,7 +172,7 @@ def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block block_state = self.get_block_state(state) if block_fn is not None: block_state = block_fn(block_state, state) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state return TestBlock @@ -403,7 +403,7 @@ class DenoiseLoop(PipelineBlock): for t in range(block_state.num_inference_steps): # ... loop logic here pass - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -455,7 +455,7 @@ class LoopWrapper(LoopSequentialPipelineBlocks): for i in range(block_state.num_steps): # loop_step executes all registered blocks in sequence components, block_state = self.loop_step(components, block_state, i=i) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -464,7 +464,7 @@ class LoopWrapper(LoopSequentialPipelineBlocks): Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: * It receives the iteration variable (e.g., `i`) passed by the loop wrapper * It works directly with `block_state` instead of pipeline state -* No need to call `self.get_block_state()` or `self.add_block_state()` +* No need to call `self.get_block_state()` or `self.set_block_state()` ```py class LoopBlock(PipelineBlock): diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 314a4126d2bd..885d37fc8eca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -240,7 +240,6 @@ [ "ComponentsManager", "ComponentSpec", - "ModularLoader", "ModularPipeline", "ModularPipelineBlocks", ] @@ -360,7 +359,7 @@ _import_structure["modular_pipelines"].extend( [ "StableDiffusionXLAutoBlocks", - "StableDiffusionXLModularLoader", + "StableDiffusionXLModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -881,7 +880,6 @@ from .modular_pipelines import ( ComponentsManager, ComponentSpec, - ModularLoader, ModularPipeline, ModularPipelineBlocks, ) @@ -983,7 +981,7 @@ else: from .modular_pipelines import ( StableDiffusionXLAutoBlocks, - StableDiffusionXLModularLoader, + StableDiffusionXLModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 9b18c8b048f9..bf34eed28b8c 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -29,7 +29,6 @@ "AutoPipelineBlocks", "SequentialPipelineBlocks", "LoopSequentialPipelineBlocks", - "ModularLoader", "PipelineState", "BlockState", ] @@ -40,7 +39,7 @@ "OutputParam", "InsertableDict", ] - _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularLoader"] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -55,7 +54,6 @@ AutoPipelineBlocks, BlockState, LoopSequentialPipelineBlocks, - ModularLoader, ModularPipeline, ModularPipelineBlocks, PipelineBlock, @@ -71,7 +69,7 @@ ) from .stable_diffusion_xl import ( StableDiffusionXLAutoBlocks, - StableDiffusionXLModularLoader, + StableDiffusionXLModularPipeline, ) else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index a1bdd86e8cdc..cf6501ad2799 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -38,26 +38,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# YiYi Notes: copied from modeling_utils.py (decide later where to put this) -def get_memory_footprint(self, return_buffers=True): - r""" - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to - benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch - discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 - - Arguments: - return_buffers (`bool`, *optional*, defaults to `True`): - Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are - tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm - layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 - """ - mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) - if return_buffers: - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) - mem = mem + mem_bufs - return mem - - class CustomOffloadHook(ModelHook): """ A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are @@ -170,6 +150,8 @@ class AutoOffloadStrategy: the available memory on the device. """ + # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device + # the actual memory usage would be higher. But it's simpler this way, and can be tested def __init__(self, memory_reserve_margin="3GB"): self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) @@ -177,7 +159,7 @@ def __call__(self, hooks, model_id, model, execution_device): if len(hooks) == 0: return [] - current_module_size = get_memory_footprint(model) + current_module_size = model.get_memory_footprint() mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] mem_on_device = mem_on_device - self.memory_reserve_margin @@ -190,12 +172,13 @@ def __call__(self, hooks, model_id, model, execution_device): # exlucde models that's not currently loaded on the device module_sizes = dict( sorted( - {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), + {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(), key=lambda x: x[1], reverse=True, ) ) + # YiYi/Dhruv TODO: sort smallest to largest, and offload in that order we would tend to keep the larger models on GPU more often def search_best_candidate(module_sizes, min_memory_offload): """ search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a @@ -652,7 +635,7 @@ def get_model_info( info.update( { "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), + "size_gb": component.get_memory_footprint() / (1024**3), "adapters": None, # Default to None "has_hook": has_hook, "execution_device": execution_device, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 99db80d3151c..d0429a1f45bf 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -19,6 +19,7 @@ from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass, field +from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -55,9 +56,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_LOADER_MAPPING = OrderedDict( +MODULAR_PIPELINE_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), + ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ] +) + +MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( + [ + ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), ] ) @@ -73,7 +80,7 @@ class PipelineState: input_kwargs: Dict[str, List[str]] = field(default_factory=dict) intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict) - def add_input(self, key: str, value: Any, kwargs_type: str = None): + def set_input(self, key: str, value: Any, kwargs_type: str = None): """ Add an input to the pipeline state with optional metadata. @@ -89,7 +96,7 @@ def add_input(self, key: str, value: Any, kwargs_type: str = None): else: self.input_kwargs[kwargs_type].append(key) - def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): """ Add an intermediate value to the pipeline state with optional metadata. @@ -329,25 +336,18 @@ def init_pipeline( collection: Optional[str] = None, ): """ - create a ModularLoader, optionally accept modular_repo to load from hub. + create a ModularPipeline, optionally accept modular_repo to load from hub. """ - loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) diffusers_module = importlib.import_module("diffusers") - loader_class = getattr(diffusers_module, loader_class_name) - - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - # Create the loader with the updated specs - specs = component_specs + config_specs + pipeline_class = getattr(diffusers_module, pipeline_class_name) - loader = loader_class( - specs=specs, + modular_pipeline = pipeline_class( + blocks=deepcopy(self), pretrained_model_name_or_path=pretrained_model_name_or_path, components_manager=components_manager, collection=collection, ) - modular_pipeline = ModularPipeline(blocks=deepcopy(self), loader=loader) return modular_pipeline @@ -512,12 +512,12 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.kwargs_type][k] = v return BlockState(**data) - def add_block_state(self, state: PipelineState, block_state: BlockState): + def set_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) for input_param in self.intermediate_inputs: if hasattr(block_state, input_param.name): @@ -525,7 +525,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): @@ -533,7 +533,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. @@ -541,7 +541,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): for param_name, current_value in intermediate_kwargs.items(): param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(param_name, param, input_param.kwargs_type) + state.set_intermediate(param_name, param, input_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -610,7 +610,6 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -# YiYi TODO: change blocks attribute to a different name, so it is not confused with the blocks attribute in ModularPipeline class AutoPipelineBlocks(ModularPipelineBlocks): """ A class that automatically selects a block to run based on the inputs. @@ -1524,12 +1523,12 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.kwargs_type][k] = v return BlockState(**data) - def add_block_state(self, state: PipelineState, block_state: BlockState): + def set_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): @@ -1537,7 +1536,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. @@ -1547,7 +1546,7 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): continue param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(param_name, param, input_param.kwargs_type) + state.set_intermediate(param_name, param, input_param.kwargs_type) @property def doc(self): @@ -1643,23 +1642,217 @@ def set_progress_bar_config(self, **kwargs): # 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader # 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() -class ModularLoader(ConfigMixin, PushToHubMixin): +class ModularPipeline(ConfigMixin, PushToHubMixin): """ - Base class for all Modular pipelines loaders. + Base class for all Modular pipelines. + Args: + blocks: ModularPipelineBlocks, the blocks to be used in the pipeline """ config_name = "modular_model_index.json" hf_device_map = None + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__( + self, + blocks: Optional[ModularPipelineBlocks] = None, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + components_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): + """ + Initialize the loader with a list of component specs and config specs. + """ + if blocks is None: + blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__) + if blocks_class_name is not None: + diffusers_module = importlib.import_module("diffusers") + blocks_class = getattr(diffusers_module, blocks_class_name) + blocks = blocks_class() + else: + logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") + + self.blocks = blocks + self._components_manager = components_manager + self._collection = collection + self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} + + # update component_specs and config_specs from modular_repo + if pretrained_model_name_or_path is not None: + config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) + + for name, value in config_dict.items(): + # all the components in modular_model_index.json are from_pretrained components + if ( + name in self._component_specs + and isinstance(value, (tuple, list)) + and len(value) == 3 + ): + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component + self.register_components(**register_components_dict) + + default_configs = {} + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + + self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediate_inputs: + state.set_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.set_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.set_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediate_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + _, state = self.blocks(self, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + def load_default_components(self, **kwargs): + names = [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + self.load_components(names=names, **kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + trust_remote_code: Optional[bool] = None, + components_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): + from ..pipelines.pipeline_loading_utils import _get_pipeline_class + try: + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except EnvironmentError: + blocks = None + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + pipeline_class = _get_pipeline_class(cls, config=config_dict) + + pipeline = pipeline_class( + blocks=blocks, + pretrained_model_name_or_path=pretrained_model_name_or_path, + components_manager=components_manager, + collection=collection, + **kwargs + ) + return pipeline + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained( + self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs + ): + + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def doc(self): + return self.blocks.doc + + def register_components(self, **kwargs): """ Register components with their corresponding specifications. This method is responsible for: 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) - 2. Updates the modular_model_index.json configuration for serialization - 4. Adds components to the component manager if one is attached + 2. Updates the modular_model_index.json configuration for serialization (only for from_pretrained components) + 3. Adds components to the component manager if one is attached (only for from_pretrained components) This method is called when: - Components are first initialized in __init__: @@ -1675,47 +1868,47 @@ def register_components(self, **kwargs): Notes: - Components must be created from ComponentSpec (have _diffusers_load_id attribute) - - When registering None for a component, it updates the modular_model_index.json config but sets attribute - to None + - When registering None for a component, it sets attribute to None but still syncs specs with the modular_model_index.json config """ for name, module in kwargs.items(): # current component spec component_spec = self._component_specs.get(name) if component_spec is None: - logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + logger.warning(f"ModularPipeline.register_components: skipping unknown component '{name}'") continue # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) + is_from_pretrained = component_spec.default_creation_method == "from_pretrained" # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") - + raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") + if module is not None: # actual library and class name of the module library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") - - # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config - # e.g. {"repo": "stabilityai/stable-diffusion-2-1", - # "type_hint": ("diffusers", "UNet2DConditionModel"), - # "subfolder": "unet", - # "variant": None, - # "revision": None} - component_spec_dict = self._component_spec_to_dict(component_spec) - else: - # if module is None, e.g. self.register_components(unet=None) during __init__ + # if module is None, e.g. self.register_components(unet=None) during __init__ # we do not update the spec, - # but we still need to update the modular_model_index.json config based oncomponent spec + # but we still need to update the modular_model_index.json config based on component spec library, class_name = None, None - component_spec_dict = self._component_spec_to_dict(component_spec) + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below if not is_registered: - self.register_to_config(**register_dict) + if is_from_pretrained: + self.register_to_config(**register_dict) setattr(self, name, module) if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: self._components_manager.add(name, module, self._collection) @@ -1725,14 +1918,14 @@ def register_components(self, **kwargs): # skip if the component is already registered with the same object if current_module is module: logger.info( - f"ModularLoader.register_components: {name} is already registered with same object, skipping" + f"ModularPipeline.register_components: {name} is already registered with same object, skipping" ) continue # warn if unregister if current_module is not None and module is None: logger.info( - f"ModularLoader.register_components: setting '{name}' to None " + f"ModularPipeline.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) # same type, new instance → replace but send debug log @@ -1743,67 +1936,19 @@ def register_components(self, **kwargs): and current_module != module ): logger.debug( - f"ModularLoader.register_components: replacing existing '{name}' " + f"ModularPipeline.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" ) # update modular_model_index.json config - self.register_to_config(**register_dict) + if is_from_pretrained: + self.register_to_config(**register_dict) # finally set models setattr(self, name, module) # add to component manager if one is attached if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: self._components_manager.add(name, module, self._collection) - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__( - self, - specs: List[Union[ComponentSpec, ConfigSpec]], - pretrained_model_name_or_path: Optional[str] = None, - components_manager: Optional[ComponentsManager] = None, - collection: Optional[str] = None, - **kwargs, - ): - """ - Initialize the loader with a list of component specs and config specs. - """ - self._components_manager = components_manager - self._collection = collection - self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} - self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} - - # update component_specs and config_specs from modular_repo - if pretrained_model_name_or_path is not None: - config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) - - for name, value in config_dict.items(): - # only update component_spec for from_pretrained components - if ( - name in self._component_specs - and self._component_specs[name].default_creation_method == "from_pretrained" - and isinstance(value, (tuple, list)) - and len(value) == 3 - ): - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - self._component_specs[name] = component_spec - - elif name in self._config_specs: - self._config_specs[name].default = value - - register_components_dict = {} - for name, component_spec in self._component_specs.items(): - if component_spec.default_creation_method == "from_config": - component = component_spec.create() - else: - component = None - register_components_dict[name] = component - self.register_components(**register_components_dict) - - default_configs = {} - for name, config_spec in self._config_specs.items(): - default_configs[name] = config_spec.default - self.register_to_config(**default_configs) @property def device(self) -> torch.device: @@ -1885,7 +2030,10 @@ def components(self) -> Dict[str, Any]: # return only components we've actually set as attributes on self return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} - def update(self, **kwargs): + def get_component_spec(self, name: str) -> ComponentSpec: + return deepcopy(self._component_specs[name]) + + def update_components(self, **kwargs): """ Update components and configuration values after the loader has been instantiated. @@ -1938,7 +2086,7 @@ def update(self, **kwargs): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): @@ -1953,7 +2101,7 @@ def update(self, **kwargs): component, current_component_spec.type_hint ): logger.warning( - f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) @@ -1975,7 +2123,7 @@ def update(self, **kwargs): created_components[name], current_component_spec.type_hint ): logger.warning( - f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the user passed component_spec self._component_specs[name] = component_spec @@ -1989,7 +2137,7 @@ def update(self, **kwargs): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, names: Union[List[str], str], **kwargs): + def load_components(self, names: Union[List[str], str], **kwargs): """ Load selected components from specs. @@ -2246,58 +2394,16 @@ def module_is_offloaded(module): ) return self - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained( - self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs - ): - component_names = list(self._component_specs.keys()) - config_names = list(self._config_specs.keys()) - self.register_to_config(_components_names=component_names, _configs_names=config_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - config.pop("_components_names", None) - config.pop("_configs_names", None) - self._internal_dict = FrozenDict(config) - - @classmethod - @validate_hf_hub_args - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - spec_only: bool = True, - components_manager: Optional[ComponentsManager] = None, - collection: Optional[str] = None, - **kwargs, - ): - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - expected_component = set(config_dict.pop("_components_names")) - expected_config = set(config_dict.pop("_configs_names")) - - component_specs = [] - config_specs = [] - for name, value in config_dict.items(): - if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - # only pick up pretrained components from the repo - if component_spec_dict.get("repo", None) is not None: - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) - - elif name in expected_config: - config_specs.append(ConfigSpec(name=name, default=value)) - - return cls(component_specs + config_specs, components_manager=components_manager, collection=collection) @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. + If the default_creation_method is not from_pretrained, return None. This dict contains: - "type_hint": Tuple[str, str] - The fully‐qualified module path and class name of the component. + Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) - All loading fields defined by `component_spec.loading_fields()`, typically: - "repo": Optional[str] The model repository (e.g., "stabilityai/stable-diffusion-xl"). @@ -2317,23 +2423,36 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: Dict[str, Any]: A mapping suitable for JSON serialization. Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers.models.unet - import UNet2DConditionModel >>> spec = ComponentSpec( ... name="unet", ... type_hint=UNet2DConditionModel, - ... config=None, ... repo="path/to/repo", ... subfolder="subfolder", ... variant=None, ... revision=None, - ... default_creation_method="from_pretrained", ... ) >>> ModularLoader._component_spec_to_dict(spec) { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": - "subfolder", "variant": None, "revision": None, + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularPipeline._component_spec_to_dict(spec) + { + "type_hint": ("diffusers", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, } """ + if component_spec.default_creation_method != "from_pretrained": + return None + if component_spec.type_hint is not None: lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) else: lib_name = None cls_name = None - if component_spec.default_creation_method == "from_pretrained": - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} - else: - load_spec_dict = {} + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} return { "type_hint": (lib_name, cls_name), **load_spec_dict, @@ -2345,7 +2464,51 @@ def _dict_to_component_spec( spec_dict: Dict[str, Any], ) -> ComponentSpec: """ - Reconstruct a ComponentSpec from a dict. + Reconstruct a ComponentSpec from a loading specdict. + + This method converts a dictionary representation back into a ComponentSpec object. + The dict should contain: + - "type_hint": Tuple[str, str] + Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + name (str): + The name of the component. + specdict (Dict[str, Any]): + A dictionary containing the component specification data. + + Returns: + ComponentSpec: A reconstructed ComponentSpec object. + + Example: + >>> spec_dict = { + ... "type_hint": ("diffusers", "UNet2DConditionModel"), + ... "repo": "stabilityai/stable-diffusion-xl", + ... "subfolder": "unet", + ... "variant": None, + ... "revision": None, + ... } + >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + config=None, + repo="stabilityai/stable-diffusion-xl", + subfolder="unet", + variant=None, + revision=None, + default_creation_method="from_pretrained" + ) """ # make a shallow copy so we can pop() safely spec_dict = spec_dict.copy() @@ -2361,133 +2524,4 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) - - -class ModularPipeline: - """ - Base class for all Modular pipelines. - - Args: - blocks: ModularPipelineBlocks, the blocks to be used in the pipeline - loader: ModularLoader, the loader to be used in the pipeline - """ - - def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): - self.blocks = blocks - self.loader = loader - - def __repr__(self): - return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)" - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.blocks.inputs: - params[input_param.name] = input_param.default - return params - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] - for expected_input_param in self.blocks.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediate_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.blocks.intermediate_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.blocks(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - def load_default_components(self, **kwargs): - names = [ - name - for name in self.loader._component_specs.keys() - if self.loader._component_specs[name].default_creation_method == "from_pretrained" - ] - self.loader.load(names=names, **kwargs) - - def load_components(self, names: Union[List[str], str], **kwargs): - self.loader.load(names=names, **kwargs) - - def update_components(self, **kwargs): - self.loader.update(**kwargs) - - @classmethod - @validate_hf_hub_args - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - trust_remote_code: Optional[bool] = None, - components_manager: Optional[ComponentsManager] = None, - collection: Optional[str] = None, - **kwargs, - ): - blocks = ModularPipelineBlocks.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) - pipeline = blocks.init_pipeline( - pretrained_model_name_or_path, components_manager=components_manager, collection=collection, **kwargs - ) - return pipeline - - def save_pretrained( - self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs - ): - self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - @property - def doc(self): - return self.blocks.doc - - def to(self, *args, **kwargs): - self.loader.to(*args, **kwargs) - return self - - @property - def components(self): - return self.loader.components + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 37696f5dfac6..90f4586753d4 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -191,7 +191,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) # the config info is lost in the process - # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 95461cfc23c9..59ec46dc6d36 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] - _import_structure["modular_blocks_presets"] = [ + _import_structure["modular_blocks"] = [ "ALL_BLOCKS", "AUTO_BLOCKS", "CONTROLNET_BLOCKS", @@ -36,7 +36,7 @@ "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", ] - _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -48,7 +48,7 @@ from .encoders import ( StableDiffusionXLTextEncoderStep, ) - from .modular_blocks_presets import ( + from .modular_blocks import ( ALL_BLOCKS, AUTO_BLOCKS, CONTROLNET_BLOCKS, @@ -62,7 +62,7 @@ StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, ) - from .modular_loader import StableDiffusionXLModularLoader + from .modular_pipeline import StableDiffusionXLModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 04da975aec8f..b064a74cbfa0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -30,7 +30,7 @@ PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -338,7 +338,7 @@ def check_inputs(self, components, block_state): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) self.check_inputs(components, block_state) @@ -388,7 +388,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -491,7 +491,7 @@ def get_timesteps(components, num_inference_steps, strength, device, denoising_s return timesteps, num_inference_steps @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -537,7 +537,7 @@ def denoising_value_valid(dnv): ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -576,7 +576,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -606,7 +606,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -851,7 +851,7 @@ def prepare_mask_latents( return mask, masked_image_latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -900,7 +900,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.generator, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -961,7 +961,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -981,7 +981,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.add_noise, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1066,7 +1066,7 @@ def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype return latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) if block_state.dtype is None: @@ -1091,7 +1091,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.latents, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1249,7 +1249,7 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -1304,7 +1304,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1420,7 +1420,7 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -1475,7 +1475,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1590,7 +1590,7 @@ def prepare_control_image( return image @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # (1) prepare controlnet inputs @@ -1693,7 +1693,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1824,7 +1824,7 @@ def prepare_control_image( return image @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) controlnet = unwrap_module(components.controlnet) @@ -1904,6 +1904,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 92b84b8595e4..878e991dbf63 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -152,7 +152,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state.images, output_type=block_state.output_type ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -212,6 +212,6 @@ def __call__(self, components, state: PipelineState) -> PipelineState: for i in block_state.images ] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index fd61c235c26b..7fe4a472eec3 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -29,7 +29,7 @@ PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -66,7 +66,7 @@ def intermediate_inputs(self) -> List[str]: ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) return components, block_state @@ -131,7 +131,7 @@ def check_inputs(components, block_state): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) @@ -202,7 +202,7 @@ def intermediate_inputs(self) -> List[str]: @torch.no_grad() def __call__( - self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int + self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int ) -> PipelineState: # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) @@ -347,7 +347,7 @@ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): return extra_kwargs @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): extra_controlnet_kwargs = self.prepare_extra_kwargs( components.controlnet.forward, **block_state.controlnet_kwargs ) @@ -494,7 +494,7 @@ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): return extra_kwargs @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_kwargs( components.scheduler.step, generator=block_state.generator, eta=block_state.eta @@ -595,7 +595,7 @@ def check_inputs(self, components, block_state): raise ValueError(f"noise is required for this step {self.__class__.__name__}") @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -677,7 +677,7 @@ def loop_intermediate_inputs(self) -> List[InputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False @@ -698,7 +698,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt ): progress_bar.update() - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index b4526537d7d4..bd0e962140e8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -37,7 +37,7 @@ ) from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -65,8 +65,7 @@ def description(self) -> str: return ( "IP Adapter step that prepares ip adapter image embeddings.\n" "Note that this step only prepares the embeddings - in order for it to work correctly, " - "you need to load ip adapter weights into unet via ModularPipeline.loader.\n" - "e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().\n" + "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n" "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) @@ -191,7 +190,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 @@ -212,7 +211,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.negative_ip_adapter_embeds.append(negative_image_embeds) block_state.ip_adapter_embeds[i] = image_embeds - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -537,7 +536,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: # Get inputs and intermediates block_state = self.get_block_state(state) self.check_inputs(block_state) @@ -573,7 +572,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt clip_skip=block_state.clip_skip, ) # Add outputs - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -663,7 +662,7 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge return image_latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} block_state.device = components._execution_device @@ -687,7 +686,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt components, image=block_state.image, generator=block_state.generator ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -841,7 +840,7 @@ def prepare_mask_latents( return mask, masked_image_latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -898,6 +897,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.generator, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py similarity index 100% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py similarity index 99% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index 82c4d6de0fe4..90850ea53606 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -23,7 +23,7 @@ from ...pipelines.pipeline_utils import StableDiffusionMixin from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ...utils import logging -from ..modular_pipeline import ModularLoader +from ..modular_pipeline import ModularPipeline from ..modular_pipeline_utils import InputParam, OutputParam @@ -32,13 +32,13 @@ # YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? # YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader +## (1) it should inherit from ModularPipeline ## (2) acts like a container that holds components and configs ## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents ## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) ## (5) how to use together with Components_manager? -class StableDiffusionXLModularLoader( - ModularLoader, +class StableDiffusionXLModularPipeline( + ModularPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 7e48cca09393..b5ac6cc3012f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -468,7 +468,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline": + if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 496039a436e5..b192b58531ac 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1349,21 +1349,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularLoader(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class ModularPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a9daf50a7a7c..62f173569520 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularLoader(metaclass=DummyObject): +class StableDiffusionXLModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 4f8b6f5a150c61aa227665ccf2041402b6eb2d1f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 6 Jul 2025 03:23:31 +0200 Subject: [PATCH 146/170] style + copy --- .../modular_pipelines/components_manager.py | 2 +- .../modular_pipelines/modular_pipeline.py | 103 ++++++------------ src/diffusers/utils/dummy_pt_objects.py | 15 +++ 3 files changed, 51 insertions(+), 69 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index cf6501ad2799..7f1c205d3ec5 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -150,7 +150,7 @@ class AutoOffloadStrategy: the available memory on the device. """ - # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device + # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device # the actual memory usage would be higher. But it's simpler this way, and can be tested def __init__(self, memory_reserve_margin="3GB"): self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d0429a1f45bf..7d640fc25d2a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -19,7 +19,6 @@ from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass, field -from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -343,7 +342,7 @@ def init_pipeline( pipeline_class = getattr(diffusers_module, pipeline_class_name) modular_pipeline = pipeline_class( - blocks=deepcopy(self), + blocks=deepcopy(self), pretrained_model_name_or_path=pretrained_model_name_or_path, components_manager=components_manager, collection=collection, @@ -1686,11 +1685,7 @@ def __init__( for name, value in config_dict.items(): # all the components in modular_model_index.json are from_pretrained components - if ( - name in self._component_specs - and isinstance(value, (tuple, list)) - and len(value) == 3 - ): + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) component_spec.default_creation_method = "from_pretrained" @@ -1794,15 +1789,16 @@ def from_pretrained( components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs, - ): + ): from ..pipelines.pipeline_loading_utils import _get_pipeline_class + try: blocks = ModularPipelineBlocks.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) except EnvironmentError: blocks = None - + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1818,33 +1814,29 @@ def from_pretrained( "local_files_only": local_files_only, "revision": revision, } - - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + + config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) pipeline_class = _get_pipeline_class(cls, config=config_dict) pipeline = pipeline_class( - blocks=blocks, - pretrained_model_name_or_path=pretrained_model_name_or_path, - components_manager=components_manager, - collection=collection, - **kwargs + blocks=blocks, + pretrained_model_name_or_path=pretrained_model_name_or_path, + components_manager=components_manager, + collection=collection, + **kwargs, ) return pipeline # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained( - self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs - ): - + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) @property def doc(self): return self.blocks.doc - def register_components(self, **kwargs): """ Register components with their corresponding specifications. @@ -1868,7 +1860,8 @@ def register_components(self, **kwargs): Notes: - Components must be created from ComponentSpec (have _diffusers_load_id attribute) - - When registering None for a component, it sets attribute to None but still syncs specs with the modular_model_index.json config + - When registering None for a component, it sets attribute to None but still syncs specs with the + modular_model_index.json config """ for name, module in kwargs.items(): # current component spec @@ -1884,12 +1877,12 @@ def register_components(self, **kwargs): # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") - + if module is not None: # actual library and class name of the module library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") else: - # if module is None, e.g. self.register_components(unet=None) during __init__ + # if module is None, e.g. self.register_components(unet=None) during __init__ # we do not update the spec, # but we still need to update the modular_model_index.json config based on component spec library, class_name = None, None @@ -1949,7 +1942,6 @@ def register_components(self, **kwargs): if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: self._components_manager.add(name, module, self._collection) - @property def device(self) -> torch.device: r""" @@ -2394,12 +2386,11 @@ def module_is_offloaded(module): ) return self - @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. - If the default_creation_method is not from_pretrained, return None. + Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. If the + default_creation_method is not from_pretrained, return None. This dict contains: - "type_hint": Tuple[str, str] @@ -2423,30 +2414,19 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: Dict[str, Any]: A mapping suitable for JSON serialization. Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec - >>> from diffusers import UNet2DConditionModel - >>> spec = ComponentSpec( - ... name="unet", - ... type_hint=UNet2DConditionModel, - ... config=None, - ... repo="path/to/repo", - ... subfolder="subfolder", - ... variant=None, - ... revision=None, - ... default_creation_method="from_pretrained", - ... ) - >>> ModularPipeline._component_spec_to_dict(spec) - { - "type_hint": ("diffusers", "UNet2DConditionModel"), - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": None, - "revision": None, + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import + UNet2DConditionModel >>> spec = ComponentSpec( + ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ... + subfolder="subfolder", ... variant=None, ... revision=None, ... + default_creation_method="from_pretrained", + ... ) >>> ModularPipeline._component_spec_to_dict(spec) { + "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder", + "variant": None, "revision": None, } """ if component_spec.default_creation_method != "from_pretrained": return None - + if component_spec.type_hint is not None: lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) else: @@ -2466,8 +2446,7 @@ def _dict_to_component_spec( """ Reconstruct a ComponentSpec from a loading specdict. - This method converts a dictionary representation back into a ComponentSpec object. - The dict should contain: + This method converts a dictionary representation back into a ComponentSpec object. The dict should contain: - "type_hint": Tuple[str, str] Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) - All loading fields defined by `component_spec.loading_fields()`, typically: @@ -2491,23 +2470,11 @@ def _dict_to_component_spec( ComponentSpec: A reconstructed ComponentSpec object. Example: - >>> spec_dict = { - ... "type_hint": ("diffusers", "UNet2DConditionModel"), - ... "repo": "stabilityai/stable-diffusion-xl", - ... "subfolder": "unet", - ... "variant": None, - ... "revision": None, - ... } - >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) - ComponentSpec( - name="unet", - type_hint=UNet2DConditionModel, - config=None, - repo="stabilityai/stable-diffusion-xl", - subfolder="unet", - variant=None, - revision=None, - default_creation_method="from_pretrained" + >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo": + "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ... + } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec( + name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl", + subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained" ) """ # make a shallow copy so we can pop() safely @@ -2524,4 +2491,4 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) \ No newline at end of file + ) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b192b58531ac..ea1999da1853 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class PerturbedAttentionGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SkipLayerGuidance(metaclass=DummyObject): _backends = ["torch"] From 23de59e21a202acc2ba39fd632da0f3b8ee0d8c6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 6 Jul 2025 06:18:34 +0200 Subject: [PATCH 147/170] add sub_blocks for pipelineBlock --- src/diffusers/modular_pipelines/components_manager.py | 1 + src/diffusers/modular_pipelines/modular_pipeline.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 7f1c205d3ec5..48f37788ad21 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -229,6 +229,7 @@ class ComponentsManager: def __init__(self): self.components = OrderedDict() + # YiYi TODO: can remove once confirm we don't need this in mellon self.added_time = OrderedDict() # Store when components were added self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 7d640fc25d2a..f23c5644f828 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -353,6 +353,9 @@ def init_pipeline( class PipelineBlock(ModularPipelineBlocks): model_name = None + def __init__(self): + self.sub_blocks = InsertableDict() + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -2129,6 +2132,7 @@ def update_components(self, **kwargs): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs + # YiYi/Dhruv TODO: consolidate load_components and load_default_components? def load_components(self, names: Union[List[str], str], **kwargs): """ Load selected components from specs. From 7cea9a3bb0d7d9361cac741e8d29d83b0e501e0e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:48:28 +0200 Subject: [PATCH 148/170] add a guider section on doc --- .../en/modular_diffusers/getting_started.md | 241 +++++++++++++++++- 1 file changed, 236 insertions(+), 5 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index ff1633988e5f..8b2cfa471525 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -605,9 +605,12 @@ When you call `pipeline.load_components(...)`/`pipeline.load_default_components( ### Updating components in a `ModularPipeline` -Similar to `DiffusionPipeline`, You could load an components separately to replace the default one in the pipeline. But in Modular Diffusers system, you need to use `ComponentSpec` to load/create them. +Similar to `DiffusionPipeline`, you can load components separately to replace the default ones in the pipeline. In Modular Diffusers, the approach depends on the component type: -`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`. In Modular Diffusers, all pretrained models should be loaded using `ComponentSpec` objects. +- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them, as they get tagged with a unique ID that encodes their loading parameters +- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly (which will call `create()` under the hood). + +`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`. So instead of @@ -642,8 +645,8 @@ t2i_pipeline.update_components(unet=unet2) Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0". This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet. ``` ->>> t2i_pipeline.loader -StableDiffusionXLModularLoader { +>>> t2i_pipeline +StableDiffusionXLModularPipeline { ... "unet": [ "diffusers", @@ -682,11 +685,239 @@ ComponentSpec( >>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0" # Load the component with the modified spec ->>> unet = unet_spec.load() +>>> unet = unet_spec.load(torch_dtype=torch.float16) ``` +### Customizing Guidance Techniques + +Guiders are guidance techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by modifying the noise predictions or model behavior to steer the generation process in desired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance` and can be easily integrated into modular pipelines, providing a flexible way to enhance generation quality without modifying the underlying diffusion models. + +**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. But we offer many more guidance techniques beyond CFG, including **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide even better results for specific use cases. + +This section demonstrates how to use guiders using the component updating methods we just learned. Since `BaseGuidance` components are stateless (similar to schedulers), they are typically created with default configurations during pipeline initialization using `default_creation_method='from_config'`. This means they don't require loading specs from the repository - you won't see guider listed in `modular_model_index.json` files. + +Let's take a look at the default guider configuration: + +```py +>>> t2i_pipeline.get_component_spec("guider") +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +``` + +As you can see, the guider is configured to use `ClassifierFreeGuidance` with default parameters and `default_creation_method='from_config'`, meaning it's created during pipeline initialization rather than loaded from a repository. Let's verify this, here we run `init_pipeline()` without a modular repo, and there it is, a guider with the default configuration we just saw + + +```py +>>> pipeline = t2i_blocks.init_pipeline() +>>> pipeline.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 7.5, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} +``` + +#### Modify Parameters of the Same Guider Type + +To change parameters of the same guider type (e.g., adjusting the `guidance_scale` for CFG), you have two options: + +**Option 1: Use ComponentSpec.create() method** +```python +>>> guider_spec = t2i_pipeline.get_component_spec("guider") +>>> guider = guider_spec.create(guidance_scale=10) +>>> t2i_pipeline.update_components(guider=guider) +``` + +**Option 2: Pass ComponentSpec directly** +```python +>>> guider_spec = t2i_pipeline.get_component_spec("guider") +>>> guider_spec.config["guidance_scale"] = 10 +>>> t2i_pipeline.update_components(guider=guider_spec) +``` + +Both approaches produce the same result: +```python +>>> t2i_pipeline.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 10, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} +``` + +#### Switch to a Different Guider Type + +Since guiders are `from_config` components (ConfigMixin objects), you can pass guider objects directly to switch between different guidance techniques: + +```py +from diffusers import LayerSkipConfig, PerturbedAttentionGuidance +config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False) +guider = PerturbedAttentionGuidance( + guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config +) +t2i_pipeline.update_components(guider=guider) +``` + +Note that you will get a warning about changing the guider type, which is expected: + +``` +ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance +``` + + + +💡 **Component Loading Methods**: +- For `from_config` components (like guiders, schedulers): You can pass the object directly OR pass a ComponentSpec directly (which calls `create()` under the hood) +- For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading + + + +Let's verify that the guider has been updated: + +```py +>>> t2i_pipeline.guider +PerturbedAttentionGuidance { + "_class_name": "PerturbedAttentionGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 5.0, + "perturbed_guidance_config": { + "dropout": 1.0, + "fqn": "mid_block.attentions.0.transformer_blocks", + "indices": [ + 2, + 9 + ], + "skip_attention": false, + "skip_attention_scores": true, + "skip_ff": false + }, + "perturbed_guidance_layers": null, + "perturbed_guidance_scale": 2.5, + "perturbed_guidance_start": 0.01, + "perturbed_guidance_stop": 0.2, + "skip_layer_config": [ + { + "dropout": 1.0, + "fqn": "mid_block.attentions.0.transformer_blocks", + "indices": [ + 2, + 9 + ], + "skip_attention": false, + "skip_attention_scores": true, + "skip_ff": false + } + ], + "skip_layer_guidance_layers": null, + "skip_layer_guidance_scale": 2.5, + "skip_layer_guidance_start": 0.01, + "skip_layer_guidance_stop": 0.2, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} + +``` + +The component spec has also been updated to reflect the new guider type: + +```py +>>> t2i_pipeline.get_component_spec("guider") +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['use_original_formulation', 'perturbed_guidance_stop', 'stop', 'guidance_rescale', 'start', 'perturbed_guidance_layers', 'perturbed_guidance_start']), ('skip_layer_guidance_scale', 2.5), ('skip_layer_guidance_start', 0.01), ('skip_layer_guidance_stop', 0.2), ('skip_layer_guidance_layers', None), ('skip_layer_config', [LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)]), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +``` + +However, the "guider" is still not included in the pipeline config and will not be saved into the `modular_model_index.json` since it remains a `from_config` component: + +```py +>>> assert "guider" not in t2i_pipeline.config +``` + +#### Upload Custom Guider to Hub for Easy Loading & Sharing + +You can upload your customized guider to the Hub so that it can be loaded more easily: + +```py +guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider") +``` + +Voilà! Now you have a subfolder called `pag_guider` on that repository. Let's change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created: + +```python +guider_spec = t2i_pipeline.get_component_spec("guider") +guider_spec.default_creation_method="from_pretrained" +guider_spec.repo="YiYiXu/modular-loader-t2i-guider" +guider_spec.subfolder="pag_guider" +pag_guider = guider_spec.load() +t2i_pipeline.update_components(guider=pag_guider) +``` + +You will get a warning about changing the creation method: + +``` +ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained. +``` + +Now not only the `guider` component and its component_spec are updated, but so is the pipeline config. Let's push it to a new repository: + +```py +t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider") +``` + +If you check the `modular_model_index.json`, you'll see the guider is now included: + +```json +{ + "guider": [ + "diffusers", + "PerturbedAttentionGuidance", + { + "repo": "YiYiXu/modular-loader-t2i-guider", + "revision": null, + "subfolder": "pag_guider", + "type_hint": [ + "diffusers", + "PerturbedAttentionGuidance" + ], + "variant": null + } + ] +} +``` + +Now when you create the pipeline from that repo directly, the `guider` is not automatically loaded anymore (since it's now a `from_pretrained` component), but when you run `load_default_components()`, the PAG guider will be loaded by default: + +```py +t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider") +assert t2i_pipeline.guider is None +t2i_pipeline.load_default_components() +t2i_pipeline.guider +``` + +Of course, you can also directly modify the `modular_model_index.json` to add a loading spec for the guider by pointing to a folder containing the desired guider config. + + + + +💡 **Guidance Techniques Summary**: +- **ClassifierFreeGuidance (CFG)**: The standard choice, best for general use and prompt adherence +- **PerturbedAttentionGuidance (PAG)**: Enhances attention-based features by perturbing attention mechanisms +- **SkipLayerGuidance (SLG)**: Improves structure and anatomy coherence by skipping specific layers +- **SmoothedEnergyGuidance (SEG)**: Helps with energy distribution smoothing +- **AdaptiveProjectedGuidance (APG)**: Adaptive guidance that projects predictions for better quality + +Experiment with different techniques and parameters to find what works best for your specific use case! + + ### Running a `ModularPipeline` From 0a4819a75575bf0d183b544bf92a27eade1868aa Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:49:29 +0200 Subject: [PATCH 149/170] add sub_folder to save_pretrained() for config mixin --- src/diffusers/configuration_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 770c949ffb3d..048ddcae32f9 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -176,6 +176,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + subfolder = kwargs.pop("subfolder", None) self._upload_folder( save_directory, @@ -183,6 +184,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool token=token, commit_message=commit_message, create_pr=create_pr, + subfolder=subfolder, ) @classmethod From 229c4b355cb5e8fa2d6f7fc9b0b3275c62ac94c0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:50:04 +0200 Subject: [PATCH 150/170] add from_pretrained/save_pretrained for guider --- src/diffusers/guiders/guider_utils.py | 92 ++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 555f8897c089..22bc8bae172e 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional import torch +from huggingface_hub.utils import validate_hf_hub_args +from typing_extensions import Self + +import os from ..configuration_utils import ConfigMixin -from ..utils import get_logger +from ..utils import PushToHubMixin, get_logger + if TYPE_CHECKING: @@ -30,7 +35,7 @@ logger = get_logger(__name__) # pylint: disable=invalid-name -class BaseGuidance(ConfigMixin): +class BaseGuidance(ConfigMixin, PushToHubMixin): r"""Base class providing the skeleton for implementing guidance techniques.""" config_name = GUIDER_CONFIG_NAME @@ -198,6 +203,87 @@ def _prepare_batch( data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ) -> Self: + r""" + Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the guider + configuration saved with [`~BaseGuidance.save_pretrained`]. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + """ + config, kwargs, commit_hash = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a guider configuration object to a directory so that it can be reloaded using the + [`~BaseGuidance.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" From 179d6d958b5af63c0511667a347a1a42aff2dc61 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:50:33 +0200 Subject: [PATCH 151/170] add subfolder to push_to_hub --- src/diffusers/utils/hub_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index f80f96a3425d..637f64da85aa 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -467,6 +467,7 @@ def _upload_folder( token: Optional[str] = None, commit_message: Optional[str] = None, create_pr: bool = False, + subfolder: Optional[str] = None, ): """ Uploads all files in `working_dir` to `repo_id`. @@ -481,7 +482,7 @@ def _upload_folder( logger.info(f"Uploading the files of {working_dir} to {repo_id}.") return upload_folder( - repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr + repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr, path_in_repo=subfolder ) def push_to_hub( @@ -493,6 +494,7 @@ def push_to_hub( create_pr: bool = False, safe_serialization: bool = True, variant: Optional[str] = None, + subfolder: Optional[str] = None, ) -> str: """ Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub. @@ -534,8 +536,9 @@ def push_to_hub( repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id # Create a new empty model card and eventually tag it - model_card = load_or_create_model_card(repo_id, token=token) - model_card = populate_model_card(model_card) + if not subfolder: + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) # Save all files. save_kwargs = {"safe_serialization": safe_serialization} @@ -546,7 +549,8 @@ def push_to_hub( self.save_pretrained(tmpdir, **save_kwargs) # Update model card if needed: - model_card.save(os.path.join(tmpdir, "README.md")) + if not subfolder: + model_card.save(os.path.join(tmpdir, "README.md")) return self._upload_folder( tmpdir, @@ -554,4 +558,5 @@ def push_to_hub( token=token, commit_message=commit_message, create_pr=create_pr, + subfolder=subfolder, ) From 5af003a9e13a18cfc2c5fa002d17074fbe18e5c0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:51:04 +0200 Subject: [PATCH 152/170] update from_componeenet, update_component --- .../modular_pipelines/modular_pipeline.py | 262 ++++++++++++++---- .../modular_pipeline_utils.py | 63 +++-- 2 files changed, 257 insertions(+), 68 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index f23c5644f828..c01c6411df92 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -47,7 +47,8 @@ format_intermediates_short, make_doc_string, ) - +from huggingface_hub import create_repo +from ..utils.hub_utils import load_or_create_model_card, populate_model_card if is_accelerate_available(): import accelerate @@ -1665,7 +1666,45 @@ def __init__( **kwargs, ): """ - Initialize the loader with a list of component specs and config specs. + Initialize a ModularPipeline instance. + + This method sets up the pipeline by: + 1. creating default pipeline blocks if not provided + 2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs) + 3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided + 4. create defaultfrom_config components and register everything + + Args: + blocks: `ModularPipelineBlocks` instance. If None, will attempt to load + default blocks based on the pipeline class name. + pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided, + will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file. + components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies. + collection: Optional collection name for organizing components in the ComponentsManager. + **kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration. + + Examples: + ```python + # Initialize with custom blocks + pipeline = ModularPipeline(blocks=my_custom_blocks) + + # Initialize from pretrained configuration + pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline") + + # Initialize with components manager + pipeline = ModularPipeline( + blocks=my_blocks, + components_manager=ComponentsManager(), + collection="my_collection" + ) + ``` + + Notes: + - If blocks is None, the method will try to find default blocks based on the pipeline class name + - Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json` + - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()` + - The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained` + - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict """ if blocks is None: blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__) @@ -1715,6 +1754,10 @@ def __init__( @property def default_call_parameters(self) -> Dict[str, Any]: + """ + Returns: + - Dictionary mapping input names to their default values + """ params = {} for input_param in self.blocks.inputs: params[input_param.name] = input_param.default @@ -1722,7 +1765,40 @@ def default_call_parameters(self) -> Dict[str, Any]: def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + Execute the pipeline by running the pipeline blocks with the given inputs. + + Args: + state (`PipelineState`, optional): + PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement. + output (`str` or `List[str]`, optional): + Optional specification of what to return: + - None: Returns the complete `PipelineState` with all inputs and intermediates (default) + - str: Returns a specific intermediate value from the state (e.g. `output="image"`) + - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`) + + + Examples: + ```python + # Get complete pipeline state + state = pipeline(prompt="A beautiful sunset", num_inference_steps=20) + print(state.intermediates) # All intermediate outputs + + # Get specific output + image = pipeline(prompt="A beautiful sunset", output="image") + + # Get multiple specific outputs + results = pipeline(prompt="A beautiful sunset", output=["image", "latents"]) + image, latents = results["image"], results["latents"] + + # Continue from previous state + state = pipeline(prompt="A beautiful sunset") + new_state = pipeline(state=state, output="image") # Continue processing + ``` + + Returns: + - If `output` is None: Complete `PipelineState` containing all inputs and intermediates + - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`) + - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`) """ if state is None: state = PipelineState() @@ -1776,6 +1852,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = raise ValueError(f"Output '{output}' is not a valid output type") def load_default_components(self, **kwargs): + """ + Load from_pretrained components using the loading specs in the config dict. + + Args: + **kwargs: Additional arguments passed to `load_components()` method + """ names = [ name for name in self._component_specs.keys() @@ -1793,6 +1875,19 @@ def from_pretrained( collection: Optional[str] = None, **kwargs, ): + """ + Load a ModularPipeline from a huggingface hub repo. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, optional): + Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file. + trust_remote_code (`bool`, optional): + Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path` + components_manager (`ComponentsManager`, optional): + ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies. + collection (`str`, optional):` + Collection name for organizing components in the ComponentsManager. + """ from ..pipelines.pipeline_loading_utils import _get_pipeline_class try: @@ -1830,14 +1925,50 @@ def from_pretrained( ) return pipeline - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + """ + Save the pipeline to a directory. It does not save components, you need to save them separately. + + Args: + save_directory (`str` or `os.PathLike`): + Path to the directory where the pipeline will be saved. + push_to_hub (`bool`, optional): + Whether to push the pipeline to the huggingface hub. + **kwargs: Additional arguments passed to `save_config()` method + + + """ + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) + model_card = populate_model_card(model_card) + model_card.save(os.path.join(save_directory, "README.md")) + + # YiYi TODO: maybe order the json file to make it more readable: configs first, then components + self.save_config(save_directory=save_directory) + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) @property def doc(self): + """ + Returns: + - The docstring of the pipeline blocks + """ return self.blocks.doc def register_components(self, **kwargs): @@ -1846,25 +1977,24 @@ def register_components(self, **kwargs): This method is responsible for: 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) - 2. Updates the modular_model_index.json configuration for serialization (only for from_pretrained components) + 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components) 3. Adds components to the component manager if one is attached (only for from_pretrained components) This method is called when: - Components are first initialized in __init__: - from_pretrained components not loaded during __init__ so they are registered as None; - non from_pretrained components are created during __init__ and registered as the object itself - - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or - loader.update(guider=guider_spec) - - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"]) + - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or + loader.update_components(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"]) Args: **kwargs: Keyword arguments where keys are component names and values are component objects. E.g., register_components(unet=unet_model, text_encoder=encoder_model) Notes: - - Components must be created from ComponentSpec (have _diffusers_load_id attribute) - - When registering None for a component, it sets attribute to None but still syncs specs with the - modular_model_index.json config + - When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` + - component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method """ for name, module in kwargs.items(): # current component spec @@ -1877,10 +2007,6 @@ def register_components(self, **kwargs): is_registered = hasattr(self, name) is_from_pretrained = component_spec.default_creation_method == "from_pretrained" - # make sure the component is created from ComponentSpec - if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") - if module is not None: # actual library and class name of the module library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") @@ -1906,7 +2032,7 @@ def register_components(self, **kwargs): if is_from_pretrained: self.register_to_config(**register_dict) setattr(self, name, module) - if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + if module is not None and is_from_pretrained and self._components_manager is not None: self._components_manager.add(name, module, self._collection) continue @@ -1942,7 +2068,7 @@ def register_components(self, **kwargs): # finally set models setattr(self, name, module) # add to component manager if one is attached - if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + if module is not None and is_from_pretrained and self._components_manager is not None: self._components_manager.add(name, module, self._collection) @property @@ -1998,14 +2124,26 @@ def dtype(self) -> torch.dtype: @property def null_component_names(self) -> List[str]: + """ + Returns: + - List of names for components that needs to be loaded + """ return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None] @property def component_names(self) -> List[str]: + """ + Returns: + - List of names for all components + """ return list(self.components.keys()) @property def pretrained_component_names(self) -> List[str]: + """ + Returns: + - List of names for from_pretrained components + """ return [ name for name in self._component_specs.keys() @@ -2014,6 +2152,10 @@ def pretrained_component_names(self) -> List[str]: @property def config_component_names(self) -> List[str]: + """ + Returns: + - List of names for from_config components + """ return [ name for name in self._component_specs.keys() @@ -2022,44 +2164,60 @@ def config_component_names(self) -> List[str]: @property def components(self) -> Dict[str, Any]: + """ + Returns: + - Dictionary mapping component names to their objects (include both from_pretrained and from_config components) + """ # return only components we've actually set as attributes on self return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} def get_component_spec(self, name: str) -> ComponentSpec: + """ + Returns: + - a copy of the ComponentSpec object for the given component name + """ return deepcopy(self._component_specs[name]) def update_components(self, **kwargs): """ - Update components and configuration values after the loader has been instantiated. + Update components and configuration values and specs after the pipeline has been instantiated. This method allows you to: - 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) - 2. Update configuration values (e.g., changing requires_safety_checker flag) + 1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`) + 2. Update configuration values (e.g., changing `self.requires_safety_checker` flag) + + In addition to updating the components and configuration values as pipeline attributes, the method also updates: + - the corresponding specs in `_component_specs` and `_config_specs` + - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained` Args: - **kwargs: Component objects or configuration values to update: - - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, - text_encoder=new_encoder`) - - Configuration values: Simple values to update configuration settings (e.g., - `requires_safety_checker=False`) - - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call - create() method to create it + **kwargs: Component objects, ComponentSpec objects, or configuration values to update: + - Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method + i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules + (e.g., `unet=new_unet, text_encoder=new_encoder`) + - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component + (e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`) + - Configuration values: Simple values to update configuration settings + (e.g., `requires_safety_checker=False`) Raises: - ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) + ValueError: If a component object is not supported in ComponentSpec.from_component() method: + - nn.Module components without a valid `_diffusers_load_id` attribute + - Non-ConfigMixin components without a valid `_diffusers_load_id` attribute Examples: ```python # Update multiple components at once - loader.update(unet=new_unet_model, text_encoder=new_text_encoder) + pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder) # Update configuration values - loader.update(requires_safety_checker=False) + pipeline.update_components(requires_safety_checker=False) # Update both components and configs together - loader.update(unet=new_unet_model, requires_safety_checker=False) - # update with ComponentSpec objects - loader.update( + pipeline.update_components(unet=new_unet_model, requires_safety_checker=False) + + # Update with ComponentSpec objects (from_config only) + pipeline.update_components( guider=ComponentSpec( name="guider", type_hint=ClassifierFreeGuidance, @@ -2068,6 +2226,11 @@ def update_components(self, **kwargs): ) ) ``` + + Notes: + - Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()` + - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly + - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components() """ # extract component_specs_updates & config_specs_updates from `specs` @@ -2080,28 +2243,23 @@ def update_components(self, **kwargs): passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") - - # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method - if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): - raise ValueError( - f"The passed component '{name}' is not supported in update() method " - f"because it is not supported in `ComponentSpec.from_component()`. " - f"Please pass a ComponentSpec object instead." - ) current_component_spec = self._component_specs[name] + # warn if type changed if current_component_spec.type_hint is not None and not isinstance( component, current_component_spec.type_hint ): logger.warning( - f"ModularPipeline.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) + if new_component_spec.default_creation_method != current_component_spec.default_creation_method: + logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.") + self._component_specs[name] = new_component_spec + if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") @@ -2109,7 +2267,7 @@ def update_components(self, **kwargs): for name, component_spec in passed_component_specs.items(): if component_spec.default_creation_method == "from_pretrained": raise ValueError( - "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method" + "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method" ) created_components[name] = component_spec.create() current_component_spec = self._component_specs[name] @@ -2118,7 +2276,7 @@ def update_components(self, **kwargs): created_components[name], current_component_spec.type_hint ): logger.warning( - f"ModularPipeline.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the user passed component_spec self._component_specs[name] = component_spec @@ -2145,7 +2303,7 @@ def load_components(self, names: Union[List[str], str], **kwargs): - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ - # if not pass any names, will not load any components + if isinstance(names, str): names = [names] elif not isinstance(names, list): @@ -2393,8 +2551,8 @@ def module_is_offloaded(module): @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. If the - default_creation_method is not from_pretrained, return None. + Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. + If the `default_creation_method` is not `from_pretrained`, return None. This dict contains: - "type_hint": Tuple[str, str] diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 90f4586753d4..f33829d7ed89 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -19,12 +19,13 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union from ..configuration_utils import ConfigMixin, FrozenDict -from ..utils.import_utils import is_torch_available - +from ..utils import is_torch_available, logging +import torch if is_torch_available(): pass +logger = logging.get_logger(__name__) # pylint: disable=invalid-name class InsertableDict(OrderedDict): def insert(self, key, value, index): @@ -110,28 +111,58 @@ def __eq__(self, other): @classmethod def from_component(cls, name: str, component: Any) -> Any: - """Create a ComponentSpec from a Component created by `create` or `load` method.""" + """Create a ComponentSpec from a Component. + + Currently supports: + - Components created with `ComponentSpec.load()` method + - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders) + + Args: + name: Name of the component + component: Component object to create spec from + + Returns: + ComponentSpec object + + Raises: + ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin) + """ + + # Check if component was created with ComponentSpec.load() + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + # component has a usable load_id -> from_pretrained, no warning needed + default_creation_method = "from_pretrained" + else: + # Component doesn't have a usable load_id, check if it's a nn.Module + if isinstance(component, torch.nn.Module): + raise ValueError( + "Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method." + ) + # ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config + elif isinstance(component, ConfigMixin): + # warn if component was not created with `ComponentSpec` + if not hasattr(component, "_diffusers_load_id"): + logger.warning("Component was not created using `ComponentSpec`, defaulting to `from_config` creation method") + default_creation_method = "from_config" + else: + # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error + raise ValueError( + f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: " + f" - components created with `ComponentSpec.load()` method" + f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)." + ) - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` or `load` method") - # throw a error if component is created with `create` method but not a subclass of ConfigMixin - # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method - if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): - raise ValueError( - "We currently only support creating ComponentSpec from a component with " - "created with `ComponentSpec.load` method" - "or created with `ComponentSpec.create` and a subclass of ConfigMixin" - ) type_hint = component.__class__ - default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" if isinstance(component, ConfigMixin) and default_creation_method == "from_config": config = component.config else: config = None - - load_spec = cls.decode_load_id(component._diffusers_load_id) + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + load_spec = cls.decode_load_id(component._diffusers_load_id) + else: + load_spec = {} return cls( name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec From 0fcdd699cfd81065bf92da976af0a7c1e8198a8b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 09:55:04 +0200 Subject: [PATCH 153/170] style --- src/diffusers/guiders/guider_utils.py | 11 +- .../modular_pipelines/modular_pipeline.py | 112 +++++++++++------- .../modular_pipeline_utils.py | 18 +-- src/diffusers/utils/hub_utils.py | 7 +- 4 files changed, 89 insertions(+), 59 deletions(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 22bc8bae172e..1c0b8cb286e7 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from huggingface_hub.utils import validate_hf_hub_args from typing_extensions import Self -import os - from ..configuration_utils import ConfigMixin from ..utils import PushToHubMixin, get_logger - if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -221,8 +219,8 @@ def from_pretrained( - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the guider - configuration saved with [`~BaseGuidance.save_pretrained`]. + - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration + saved with [`~BaseGuidance.save_pretrained`]. subfolder (`str`, *optional*): The subfolder location of a model file within a larger model repository on the Hub or locally. return_unused_kwargs (`bool`, *optional*, defaults to `False`): @@ -285,6 +283,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: """ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c01c6411df92..7850eff744ad 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -22,6 +22,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from huggingface_hub import create_repo from huggingface_hub.utils import validate_hf_hub_args from tqdm.auto import tqdm from typing_extensions import Self @@ -34,6 +35,7 @@ logging, ) from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ..utils.hub_utils import load_or_create_model_card, populate_model_card from .components_manager import ComponentsManager from .modular_pipeline_utils import ( ComponentSpec, @@ -47,8 +49,7 @@ format_intermediates_short, make_doc_string, ) -from huggingface_hub import create_repo -from ..utils.hub_utils import load_or_create_model_card, populate_model_card + if is_accelerate_available(): import accelerate @@ -1670,16 +1671,21 @@ def __init__( This method sets up the pipeline by: 1. creating default pipeline blocks if not provided - 2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs) - 3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided + 2. gather component and config specifications based on the pipeline blocks's requirement (e.g. + expected_components, expected_configs) + 3. update the loading specs of from_pretrained components based on the modular_model_index.json file from + huggingface hub if `pretrained_model_name_or_path` is provided 4. create defaultfrom_config components and register everything Args: blocks: `ModularPipelineBlocks` instance. If None, will attempt to load default blocks based on the pipeline class name. pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided, - will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file. - components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies. + will load component specs (only for from_pretrained components) and config values from the saved + modular_model_index.json file. + components_manager: + Optional ComponentsManager for managing multiple component cross different pipelines and apply + offloading strategies. collection: Optional collection name for organizing components in the ComponentsManager. **kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration. @@ -1693,18 +1699,20 @@ def __init__( # Initialize with components manager pipeline = ModularPipeline( - blocks=my_blocks, - components_manager=ComponentsManager(), - collection="my_collection" + blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection" ) ``` Notes: - If blocks is None, the method will try to find default blocks based on the pipeline class name - - Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json` - - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()` - - The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained` - - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict + - Components with default_creation_method="from_config" are created immediately, its specs are not included + in config dict and will not be saved in `modular_model_index.json` + - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with + `load_default_components()`/`load_components()` + - The pipeline's config dict is populated with component specs (only for from_pretrained components) and + config values, which will be saved as `modular_model_index.json` during `save_pretrained` + - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as + `_blocks_class_name` in the config dict """ if blocks is None: blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__) @@ -1769,12 +1777,14 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = Args: state (`PipelineState`, optional): - PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement. + PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be + created based on the user inputs and the pipeline blocks's requirement. output (`str` or `List[str]`, optional): Optional specification of what to return: - None: Returns the complete `PipelineState` with all inputs and intermediates (default) - str: Returns a specific intermediate value from the state (e.g. `output="image"`) - - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`) + - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", + "latents"]`) Examples: @@ -1794,11 +1804,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = state = pipeline(prompt="A beautiful sunset") new_state = pipeline(state=state, output="image") # Continue processing ``` - + Returns: - If `output` is None: Complete `PipelineState` containing all inputs and intermediates - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`) - - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`) + - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. + `output=["image", "latents"]`) """ if state is None: state = PipelineState() @@ -1880,11 +1891,14 @@ def from_pretrained( Args: pretrained_model_name_or_path (`str` or `os.PathLike`, optional): - Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file. + Path to a pretrained pipeline configuration. If provided, will load component specs (only for + from_pretrained components) and config values from the modular_model_index.json file. trust_remote_code (`bool`, optional): - Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path` + Whether to trust remote code when loading the pipeline, need to be set to True if you want to create + pipeline blocks based on the custom code in `pretrained_model_name_or_path` components_manager (`ComponentsManager`, optional): - ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies. + ComponentsManager instance for managing multiple component cross different pipelines and apply + offloading strategies. collection (`str`, optional):` Collection name for organizing components in the ComponentsManager. """ @@ -1935,8 +1949,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: push_to_hub (`bool`, optional): Whether to push the pipeline to the huggingface hub. **kwargs: Additional arguments passed to `save_config()` method - - """ if push_to_hub: commit_message = kwargs.pop("commit_message", None) @@ -1945,12 +1957,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - + # Create a new empty model card and eventually tag it model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) model_card = populate_model_card(model_card) model_card.save(os.path.join(save_directory, "README.md")) - + # YiYi TODO: maybe order the json file to make it more readable: configs first, then components self.save_config(save_directory=save_directory) @@ -1977,7 +1989,8 @@ def register_components(self, **kwargs): This method is responsible for: 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) - 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components) + 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only + for from_pretrained components) 3. Adds components to the component manager if one is attached (only for from_pretrained components) This method is called when: @@ -1986,15 +1999,18 @@ def register_components(self, **kwargs): - non from_pretrained components are created during __init__ and registered as the object itself - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or loader.update_components(guider=guider_spec) - - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"]) + - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. + loader.load_default_components(names=["unet"]) Args: **kwargs: Keyword arguments where keys are component names and values are component objects. E.g., register_components(unet=unet_model, text_encoder=encoder_model) Notes: - - When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` - - component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method + - When registering None for a component, it sets attribute to None but still syncs specs with the config + dict, which will be saved as `modular_model_index.json` during `save_pretrained` + - component_specs are updated to match the new component outside of this method, e.g. in + `update_components()` method """ for name, module in kwargs.items(): # current component spec @@ -2166,7 +2182,8 @@ def config_component_names(self) -> List[str]: def components(self) -> Dict[str, Any]: """ Returns: - - Dictionary mapping component names to their objects (include both from_pretrained and from_config components) + - Dictionary mapping component names to their objects (include both from_pretrained and from_config + components) """ # return only components we've actually set as attributes on self return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} @@ -2186,19 +2203,21 @@ def update_components(self, **kwargs): 1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`) 2. Update configuration values (e.g., changing `self.requires_safety_checker` flag) - In addition to updating the components and configuration values as pipeline attributes, the method also updates: + In addition to updating the components and configuration values as pipeline attributes, the method also + updates: - the corresponding specs in `_component_specs` and `_config_specs` - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained` Args: **kwargs: Component objects, ComponentSpec objects, or configuration values to update: - - Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method - i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules - (e.g., `unet=new_unet, text_encoder=new_encoder`) - - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component - (e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`) - - Configuration values: Simple values to update configuration settings - (e.g., `requires_safety_checker=False`) + - Component objects: Only supports components we can extract specs using + `ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or + ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`) + - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() + method to create a new component (e.g., `guider=ComponentSpec(name="guider", + type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`) + - Configuration values: Simple values to update configuration settings (e.g., + `requires_safety_checker=False`) Raises: ValueError: If a component object is not supported in ComponentSpec.from_component() method: @@ -2228,9 +2247,11 @@ def update_components(self, **kwargs): ``` Notes: - - Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()` + - Components with trained weights must be created using ComponentSpec.load(). If the component has not been + shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()` - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly - - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components() + - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in + update_components() """ # extract component_specs_updates & config_specs_updates from `specs` @@ -2244,7 +2265,7 @@ def update_components(self, **kwargs): for name, component in passed_components.items(): current_component_spec = self._component_specs[name] - + # warn if type changed if current_component_spec.type_hint is not None and not isinstance( component, current_component_spec.type_hint @@ -2255,10 +2276,11 @@ def update_components(self, **kwargs): # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) if new_component_spec.default_creation_method != current_component_spec.default_creation_method: - logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.") - - self._component_specs[name] = new_component_spec + logger.warning( + f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}." + ) + self._component_specs[name] = new_component_spec if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") @@ -2551,8 +2573,8 @@ def module_is_offloaded(module): @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ - Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. - If the `default_creation_method` is not `from_pretrained`, return None. + Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If + the `default_creation_method` is not `from_pretrained`, return None. This dict contains: - "type_hint": Tuple[str, str] diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index f33829d7ed89..ee1f30d93d9d 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -18,15 +18,18 @@ from dataclasses import dataclass, field, fields from typing import Any, Dict, List, Literal, Optional, Type, Union +import torch + from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import is_torch_available, logging -import torch + if is_torch_available(): pass logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class InsertableDict(OrderedDict): def insert(self, key, value, index): items = list(self.items()) @@ -112,18 +115,18 @@ def __eq__(self, other): @classmethod def from_component(cls, name: str, component: Any) -> Any: """Create a ComponentSpec from a Component. - + Currently supports: - Components created with `ComponentSpec.load()` method - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders) - + Args: name: Name of the component component: Component object to create spec from - + Returns: ComponentSpec object - + Raises: ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin) """ @@ -142,7 +145,9 @@ def from_component(cls, name: str, component: Any) -> Any: elif isinstance(component, ConfigMixin): # warn if component was not created with `ComponentSpec` if not hasattr(component, "_diffusers_load_id"): - logger.warning("Component was not created using `ComponentSpec`, defaulting to `from_config` creation method") + logger.warning( + "Component was not created using `ComponentSpec`, defaulting to `from_config` creation method" + ) default_creation_method = "from_config" else: # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error @@ -152,7 +157,6 @@ def from_component(cls, name: str, component: Any) -> Any: f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)." ) - type_hint = component.__class__ if isinstance(component, ConfigMixin) and default_creation_method == "from_config": diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 637f64da85aa..8aaee5b75d93 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -482,7 +482,12 @@ def _upload_folder( logger.info(f"Uploading the files of {working_dir} to {repo_id}.") return upload_folder( - repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr, path_in_repo=subfolder + repo_id=repo_id, + folder_path=working_dir, + token=token, + commit_message=commit_message, + create_pr=create_pr, + path_in_repo=subfolder, ) def push_to_hub( From ceeb3c1da3ad7e53f0e6baf33be0298c268d794b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 10:21:01 +0200 Subject: [PATCH 154/170] fix --- docs/source/en/modular_diffusers/getting_started.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 8b2cfa471525..fce46527d123 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -316,7 +316,7 @@ You typically don't need to manually create or manage these state objects. The ` A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here. In standard `model_index.json`, each component entry is a `(library, class)` tuple: - +```py "text_encoder": [ "transformers", "CLIPTextModel" From 6521f599b2d54b64ac3a6a1aa3f249c0d65e3c80 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 7 Jul 2025 20:52:37 +0200 Subject: [PATCH 155/170] make sure modularpipeline from_pretrained works without modular_model_index --- .../modular_pipelines/modular_pipeline.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 7850eff744ad..5440e5e5a6ff 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -792,7 +792,7 @@ def fn_recursive_get_trigger(blocks): trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has sub_blocks, recursively check them - if hasattr(block, "sub_blocks"): + if block.sub_blocks: nested_triggers = fn_recursive_get_trigger(block.sub_blocks) trigger_values.update(nested_triggers) @@ -1077,7 +1077,7 @@ def fn_recursive_get_trigger(blocks): trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has sub_blocks, recursively check them - if hasattr(block, "sub_blocks"): + if block.sub_blocks: nested_triggers = fn_recursive_get_trigger(block.sub_blocks) trigger_values.update(nested_triggers) @@ -1098,7 +1098,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): # sequential(include loopsequential) or PipelineBlock if not hasattr(block, "block_trigger_inputs"): - if hasattr(block, "sub_blocks"): + if block.sub_blocks: # sequential or LoopSequentialPipelineBlocks (keep traversing) for sub_block_name, sub_block in block.sub_blocks.items(): blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) @@ -1128,7 +1128,7 @@ def fn_recursive_traverse(block, block_name, active_triggers): if this_block is not None: # sequential/auto (keep traversing) - if hasattr(this_block, "sub_blocks"): + if this_block.sub_blocks: result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock @@ -1642,9 +1642,8 @@ def set_progress_bar_config(self, **kwargs): # YiYi TODO: -# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config -# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 2. do we need ConfigSpec? the are basically just key/val kwargs # 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() class ModularPipeline(ConfigMixin, PushToHubMixin): """ @@ -1927,8 +1926,12 @@ def from_pretrained( "revision": revision, } - config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) - pipeline_class = _get_pipeline_class(cls, config=config_dict) + try: + config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) + pipeline_class = _get_pipeline_class(cls, config=config_dict) + except EnvironmentError: + pipeline_class = cls + pretrained_model_name_or_path = None pipeline = pipeline_class( blocks=blocks, From 863c7df543274849714eab5ea9f41e43fea37182 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 06:15:37 +0200 Subject: [PATCH 156/170] components manager: use shorter ID, display id instead of name --- .../modular_pipelines/components_manager.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 48f37788ad21..1ccc404d1a21 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -14,7 +14,6 @@ import copy import time -import uuid from collections import OrderedDict from itertools import combinations from typing import Any, Dict, List, Optional, Union @@ -280,7 +279,7 @@ def _id_to_name(component_id: str): return "_".join(component_id.split("_")[:-1]) def add(self, name, component, collection: Optional[str] = None): - component_id = f"{name}_{uuid.uuid4()}" + component_id = f"{name}_{id(component)}" # check for duplicated components for comp_id, comp in self.components.items(): @@ -674,16 +673,6 @@ def __repr__(self): if not self.components: return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50 - # Helper to get simple name without UUID - def get_simple_name(name): - # Extract the base name by splitting on underscore and taking first part - # This assumes names are in format "name_uuid" - parts = name.split("_") - # If we have at least 2 parts and the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]: - return "_".join(parts[:-1]) - return name - # Extract load_id if available def get_load_id(component): if hasattr(component, "_diffusers_load_id"): @@ -699,9 +688,6 @@ def format_device(component, info): exec_device = str(info["execution_device"] or "N/A") return f"{device}({exec_device})" - # Get all simple names to calculate width - simple_names = [get_simple_name(id) for id in self.components.keys()] - # Get max length of load_ids for models load_ids = [ get_load_id(component) @@ -725,7 +711,7 @@ def format_device(component, info): max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 col_widths = { - "name": max(15, max(len(name) for name in simple_names)), + "id": max(15, max(len(name) for name in self.components.keys())), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), "device": 20, "dtype": 15, @@ -748,7 +734,7 @@ def format_device(component, info): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line @@ -756,7 +742,6 @@ def format_device(component, info): # Model entries for name, component in models.items(): info = self.get_model_info(name) - simple_name = get_simple_name(name) device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) @@ -764,14 +749,14 @@ def format_device(component, info): # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] - output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | " output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" @@ -783,23 +768,22 @@ def format_device(component, info): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" + output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): info = self.get_model_info(name) - simple_name = get_simple_name(name) # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] - output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n" output += dash_line @@ -808,8 +792,7 @@ def format_device(component, info): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - simple_name = get_simple_name(name) - output += f"\n{simple_name}:\n" + output += f"\n{name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): From a2da0004ee30a93a4750431c74cf3beb8169a992 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 06:16:26 +0200 Subject: [PATCH 157/170] add a guide on components manager --- docs/source/en/_toctree.yml | 2 + .../modular_diffusers/components_manager.md | 504 ++++++++++++++++++ .../en/modular_diffusers/getting_started.md | 398 -------------- 3 files changed, 506 insertions(+), 398 deletions(-) create mode 100644 docs/source/en/modular_diffusers/components_manager.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 843c07c8a602..24e21a7a4acb 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -96,6 +96,8 @@ - sections: - local: modular_diffusers/getting_started title: Getting Started + - local: modular_diffusers/components_manager + title: Components Manager - local: modular_diffusers/write_own_pipeline_block title: Write your own pipeline block - local: modular_diffusers/end_to_end_guide diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md new file mode 100644 index 000000000000..84ed8e7d26fb --- /dev/null +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -0,0 +1,504 @@ + + +# Components Manager + +The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows. + + +## Basic Operations + +Let's start with the fundamental operations. First, create a Components Manager: + +```py +from diffusers import ComponentsManager +comp = ComponentsManager() +``` + +Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function): + +```py +from diffusers import AutoModel +text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") +# Returns component_id like 'text_encoder_139917733042864' +component_id = comp.add("text_encoder", text_encoder) +``` + +You can view all registered components and their metadata: + +```py +>>> comp +Components: +=============================================================================================================================================== +Models: +----------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +----------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +----------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +And remove components using their unique ID: + +```py +comp.remove("text_encoder_139917733042864") +``` + +## Duplicate Detection + +The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice. + +When you try to add the same object twice, the manager will warn you and return the existing ID: + +```py +>>> comp.add("text_encoder", text_encoder) +'text_encoder_139917733042864' +>>> comp.add("text_encoder", text_encoder) +ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864' +'text_encoder_139917733042864' +``` + +Even if you add the same object under a different name, it will still be detected as a duplicate: + +```py +>>> comp.add("clip", text_encoder) +ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864' +To remove a duplicate, call `components_manager.remove('')`. +'clip_139917733042864' +``` + +However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example: + +```py +>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") +>>> comp.add("text_encoder", text_encoder_2) +'text_encoder_139917732983664' +``` + +This creates a problem - you now have two copies of the same model consuming double the memory: + +```py +>>> comp +Components: +=============================================================================================================================================== +Models: +----------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +----------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +----------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint: + +```py +from diffusers import ComponentSpec, ComponentsManager +from transformers import CLIPTextModel +comp = ComponentsManager() + +# Create ComponentSpec for the first text encoder +spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel) +# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder) +spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel) + +# Load and add both components - the manager will detect they're the same model +comp.add("text_encoder", spec.load()) +comp.add("text_encoder_duplicated", spec_duplicated.load()) +``` + +Now the manager detects the duplicate and warns you: + +```out +ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`. +'text_encoder_duplicated_139917580682672' +``` + +Both models now show the same `load_id`, making it clear they're the same model: + +```py +>>> comp +Components: +====================================================================================================================================================================================================== +Models: +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ +text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A +text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + +Additional Component Info: +================================================== +``` + +## Collections + +Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed. + +Here's how collections work in practice: + +```py +comp = ComponentsManager() +# Create ComponentSpec for the first UNet (SDXL base) +spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel) +# Create ComponentSpec for a different UNet (Juggernaut-XL) +spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16") + +# Add both UNets to the same collection - the second one will replace the first +comp.add("unet", spec.load(), collection="sdxl") +comp.add("unet", spec2.load(), collection="sdxl") +``` + +The manager automatically removes the old UNet and adds the new one: + +```out +ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888 +'unet_139917723893136' +``` + +Only one UNet remains in the collection: + +```py +>>> comp +Components: +==================================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- +unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed. + +## Retrieving Components + +The Components Manager provides several methods to retrieve registered components. + +The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use: +- exact matches like `comp.get_one(name="unet")` +- wildcards like `comp.get_one(name="unet*")` for components starting with "unet" +- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet" +- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae". + +You can also filter by collection with `comp.get_one(name="unet", collection="sdxl")` or by load_id. If multiple components match, `get_one()` throws an error. + +Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`. + +```py +# Get components by name list +component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"]) +# Returns: {"text_encoder": component1, "unet": component2, "vae": component3} +``` + +## Using Components Manager with Modular Pipelines + +The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter: + +```py +from diffusers import ModularPipeline, ComponentsManager +comp = ComponentsManager() +pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1") +``` + +By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty: + +```py +>>> comp +Components: +================================================== +No components registered. +================================================== +``` + +When you load components on the pipeline, they are automatically registered in the Components Manager: + +```py +>>> pipe.load_components(names="unet") +>>> comp +Components: +============================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1 +-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection: + +```py +# Load all default components +>>> pipe.load_default_components()` + +# Create a second pipeline using the same Components Manager but with a different collection +>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2") +``` + +As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager: + +```py +# Get the list of components that pipe2 needs to load +>>> pipe2.null_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet'] + +# Retrieve all required components from the Components Manager +>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names) + +# Update the pipeline with the retrieved components +>>> pipe2.update_components(**comp_dict) +``` + +The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates: + +``` +ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400' +ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424' +ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504' +ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904' +ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304' +ComponentsManager: component 'unet' already exists as 'unet_139917580609632' +ComponentsManager: component 'vae' already exists as 'vae_139917722459040' +ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408' +ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432' +``` +``` + +The pipeline is now fully loaded: + +```py +# null_component_names return empty list, meaning everything are loaded +>>> pipe2.null_component_names +[] +``` + +No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines: +```py +>>> comp +Components: +======================================================================================================================================================================================== +Models: +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1 + | | | | | | test2 +text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1 + | | | | | | test2 +unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1 + | | | | | | test2 +controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1 + | | | | | | test2 +vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1 + | | | | | | test2 +image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1 + | | | | | | test2 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Other Components: +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +ID | Class | Collection +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +tokenizer_139917580599504 | CLIPTokenizer | test1 + | | test2 +scheduler_139916266559408 | EulerDiscreteScheduler | test1 + | | test2 +tokenizer_2_139915763443904 | CLIPTokenizer | test1 + | | test2 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + + +## Automatic Memory Management + +The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them: + +```py +comp.enable_auto_cpu_offload(device="cuda") +``` + +When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow. + + + +## Practical Example: Building Modular Workflows with Component Reuse + +Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components. + +In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will also use Lora and IP-Adapter. + +Let's create a modular text-to-image workflow by separating it into three components: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS + +# Create modular blocks and separate text encoding and decoding steps +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"]) +text_blocks = t2i_blocks.sub_blocks.pop("text_encoder") +decoder_blocks = t2i_blocks.sub_blocks.pop("decode") +``` + +Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection: + +```py +from diffusers import ComponentsManager, ModularPipeline + +# Set up Components Manager with auto offloading +components = ComponentsManager() +components.enable_auto_cpu_offload(device="cuda") + +# Create pipelines and load components +t2i_repo = "YiYiXu/modular-demo-auto" +t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") + +text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) +decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) +t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) +``` + +Load all components into the Components Manager under the "t2i" collection: + +```py +# Load all components (including IP-Adapter and ControlNet for later use) +t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16) +``` + +Now distribute the loaded components to each pipeline: + +```py +# Get VAE for decoder (using get_one since there's only one) +vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null") +decoder_node.update_components(vae=vae) + +# Get text components for text node (using get_components_by_names for multiple components) +text_components = components.get_components_by_names(text_node.null_component_names) +text_node.update_components(**text_components) + +# Get remaining components for t2i pipeline +t2i_components = components.get_components_by_names(t2i_pipe.null_component_names) +t2i_pipe.update_components(**t2i_components) +``` + +Now we can generate images using our modular workflow: + +```py +# Generate text embeddings +prompt = "an astronaut" +text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"]) + +# Generate latents and decode to image +generator = torch.Generator(device="cuda").manual_seed(0) +latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_t2i, output="images")[0] +image.save("modular_part2_t2i.png") +``` + +Let's add a LoRA: + +```py +# Load LoRA weights - only the UNet gets the adapter +>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") +>>> components +Components: +============================================================================================================================================================ +... +Additional Component Info: +================================================== + +unet: + Adapters: ['toy_face'] +``` + +You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings. + +```py +# Generate with LoRA (reusing existing text embeddings) +generator = torch.Generator(device="cuda").manual_seed(0) +latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_lora, output="images")[0] +image.save("modular_part2_lora.png") +``` + + +Now let's create a refiner pipeline that reuses components from our text-to-image workflow: + +```py +# Create refiner blocks (removing image_encoder and decode since we work with latents) +refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"]) +refiner_blocks.sub_blocks.pop("image_encoder") +refiner_blocks.sub_blocks.pop("decode") + +# Create refiner pipeline with different repo and collection +refiner_repo = "YiYiXu/modular_refiner" +refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner") +``` + +We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection. + +```py +# Load only the refiner UNet (different from t2i UNet) +refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) + +# Reuse components from t2i pipeline using pattern matching +reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2") +refiner_pipe.update_components(**reuse_components) +``` + +When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows. + +Now we can refine any of our generated latents: + +```py +# Refine all our different latents +refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_t2i_refine_out.png") + +refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_lora_refine_out.png") +``` + + +Here are the results from our modular pipeline examples. + +#### Base Text-to-Image Generation +| Base Text-to-Image | Base Text-to-Image (Refined) | +|-------------------|------------------------------| +| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) | + +#### LoRA +| LoRA | LoRA (Refined) | +|-------------------|------------------------------| +| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) | + diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index fce46527d123..1eb39c04d432 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -1219,402 +1219,4 @@ image = pipeline( image.save("modular_ipa_out.png") ``` -## Building Advanced Workflows: The Modular Way - -We've learned the basic components of the Modular Diffusers System. Now let's tie everything together with more practical example that demonstrates the true power of Modular Diffusers: working between with multiple pipelines that can share components. - -In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will use IP-adapter, LoRA, and ControlNet. - -### Base Text-to-Image - -Let's setup the text-to-image workflow. Instead of putting all blocks into one complete pipeline, we'll create separate `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. - - -```py -import torch -from diffusers.modular_pipelines import SequentialPipelineBlocks -from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS - -# create t2i blocks and then pop out the text_encoder step and decoder step so that we can use them in standalone manner -t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"]) -text_blocks = t2i_blocks.sub_blocks.pop("text_encoder") -decoder_blocks = t2i_blocks.sub_blocks.pop("decode") -``` - -Next, convert them into runnable pipelines. We'll use a Components Manager with auto offloading strategy. - -**Components Manager**: Create one manager and pass it to `init_pipeline` along with a collection name. All models loaded by that pipeline will be added to the manager under that collection. - -**Auto Offloading**: All components are placed on CPU and only moved to device right before their forward pass. The manager monitors device memory and may move components off-device to make space for new ones. Unlike `DiffusionPipeline.enable_model_cpu_offload()`, this works across all components in the manager and all your workflows. - - -```py -from diffusers import ComponentsManager -# Set up component manager and turn on the offloading -components = ComponentsManager() -components.enable_auto_cpu_offload(device="cuda") -``` - -Since we have a modular setup where different pipelines may share components, we recommend using a seperate `ModularPipeline` to load components all at once and add them to each pipeline with `update_components()`. - - -```py -from diffusers import ModularPipeline -t2i_repo = "YiYiXu/modular-demo-auto" -t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") - -text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) -decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) -t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) -``` - -We'll load components in `t2i_loader_pipe`. You can get the list of all loadable components from loader's `pretrained_component_names` property. - -```py ->>> t2i_loader_pipe.pretrained_component_names -['controlnet', 'image_encoder', 'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'] -``` - -It include controlnet and image_encoder for ip-adapter that we don't need now. But I'll load them anyway since they'll stay on CPU and I might use them later. But you can choose what to load in the `names` argument. - -```py -import torch -# inspect before you load -# t2i_loader -t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16) -``` -All the models are registered to components manager under the collection "t2i". - -```py ->>> components -Components: -============================================================================================================================================================ -Models: ------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection ------------------------------------------------------------------------------------------------------------------------------------------------------------- -vae | AutoencoderKL | cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i -image_encoder | CLIPVisionModelWithProjection| cpu(cuda:0) | torch.float16| 3.44 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | t2i -text_encoder | CLIPTextModel | cpu(cuda:0) | torch.float16| 0.23 | SG161222/RealVisXL_V4.0|text_encoder|null|null | t2i -unet | UNet2DConditionModel | cpu(cuda:0) | torch.float16| 4.78 | SG161222/RealVisXL_V4.0|unet|null|null | t2i -text_encoder_2 | CLIPTextModelWithProjection | cpu(cuda:0) | torch.float16| 1.29 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | t2i -controlnet | ControlNetModel | cpu(cuda:0) | torch.float16| 2.33 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | t2i ------------------------------------------------------------------------------------------------------------------------------------------------------------- - -Other Components: ------------------------------------------------------------------------------------------------------------------------------------------------------------- -Name | Class | Collection ------------------------------------------------------------------------------------------------------------------------------------------------------------- -tokenizer_2 | CLIPTokenizer | t2i -tokenizer | CLIPTokenizer | t2i -scheduler | EulerDiscreteScheduler | t2i ------------------------------------------------------------------------------------------------------------------------------------------------------------- - -Additional Component Info: -================================================== -``` - -Let's add the loaded components to each pipeline. We'll follow this pattern for each pipeline: -1. Check what components the pipeline needs: inspect `pipeline` or use `pipeline.null_component_names` -2. Get them from the components manager: use its `search_models()`/`get_one`/`get_components_from_names` method -3. Update the pipeline: `pipeline.update_components()` -4. Verify the components are loaded correctly: inspect `pipeline` as well as components manager - -We will start with `decoder_node`. First, check what components it needs: - -```py ->>> decoder_node.null_component_names -['vae'] -``` -The pipeline only needs a `vae`. Looking at the components manager table, there's only one VAE available: - -``` -Name | Class | Device: act(exec)| Dtype | Size (GB)| Load ID | Collection ----------------------------------------------------------------------------------------------------------------------- -vae | AutoencoderKL| cpu(cuda:0) | torch.float16| 0.16 | SG161222/RealVisXL_V4.0|vae|null|null | t2i -``` -Since there's only one VAE, we can get it using its unique Load ID: - -```py -vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null") -decoder_node.update_components(vae=vae) -``` - -Verify it's correctly loaded: - -```py -decoder_node -``` -Now let's do the same for `text_node`. Get the list of components the pipeline needs to load: - -```py ->>> text_node.null_component_names -['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2'] -``` -Pass the list directly to the components manager to get the components and add it to the pipeline - -```py -text_components = components.get_components_by_names(text_node.null_component_names) -# Add components to pipeline -text_node.update_components(**text_components) - -# Verify components are loaded -assert not text_node.null_component_names -text_node -``` - -Finally, let's set up `t2i_pipe`: - -```py - -# Get unet & scheduler from components manager and add to pipeline -comps = components.get_components_by_names(t2i_pipe.null_component_names) -t2i_pipe.update_components(**comps) - -# Verify everything is loaded -assert not t2i_pipe.null_component_names -t2i_pipe - -# Verify components manager hasn't changed (we only reused existing components) -components -``` - -We can start to generate an image with the t2i pipeline. - -First to run the prompt through text_node to get prompt embeddings - - - -💡 don't forget to `text_node.doc` to find out what outputs are available and set the `output` argument accordingly - - - -```py -prompt = "an astronaut" -text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"]) -``` - -Now generate latents with t2i pipeline and then decode with decoder. - - -```py -generator = torch.Generator(device="cuda").manual_seed(0) -latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") -image = decoder_node(latents=latents_t2i, output="images")[0] -image.save("modular_part2_t2i.png") - -``` - -### Lora - -Now let's add a LoRA to our pipeline. With the modular approach we will be able to reuse intermediate outputs from blocks that otherwise needs to be re-run. Let's load the LoRA weights and see what happens: - -```py -t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") -components -``` -Notice that the "Additional Component Info" section shows that only the `unet` component has the LoRA adapter loaded. This means we can skip the text encoding step and reuse the existing embeddings, making the generation much faster. - -```out -Components: -============================================================================================================================================================ -... -Additional Component Info: -================================================== - -unet: - Adapters: ['toy_face'] -``` - - - - -🔍 Alternatively, you can find a component's ID and then use `get_model_info` to get detailed metadata about that component: - -```py -id = components.get_ids("unet")[0] -components.get_model_info(id) -# {'model_id': 'unet_6c2b839d-ec39-4ce9-8741-333ba6d25932', 'added_time': 1751101289.203884, 'collection': 't2i', 'class_name': 'UNet2DConditionModel', 'size_gb': 4.940812595188618, 'adapters': ['toy_face'], 'has_hook': True, 'execution_device': device(type='cuda', index=0)} -``` - - - -```py -generator = torch.Generator(device="cuda").manual_seed(0) -latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") -image = decoder_node(latents=latents_lora, output="images")[0] -image.save("modular_part2_lora.png") -``` - -### IP-adapter - -IP-adapter can also be used as a standalone pipeline. We can generate the embeddings once and reuse them for different workflows. - -```py -from diffusers.utils import load_image - -ipa_blocks = ALL_BLOCKS["ip_adapter"]["ip_adapter"]() -ipa_node = ipa_blocks.init_pipeline(t2i_repo, components_manager=components) -comps = components.get_components_by_names(ipa_node.loader.null_component_names) -ipa_node.update_components(**comps) - -t2i_loader_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") -t2i_loader_pipe.set_ip_adapter_scale(0.6) - -# check it's correctly loaded -assert not ipa_node.null_component_names -ipa_node -# find out inputs/outputs -print(ipa_node.doc) - -ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png") -ipa_embeddings = ipa_node(ip_adapter_image=ip_adapter_image, output=["ip_adapter_embeds","negative_ip_adapter_embeds"]) - -generator = torch.Generator(device="cuda").manual_seed(0) -latents_ipa = t2i_pipe(**text_embeddings, **ipa_embeddings, num_inference_steps=25, generator=generator, output="latents") - -image = decoder_node(latents=latents_ipa, output="images")[0] -image.save("modular_part2_lora_ipa.png") -``` - -### ControlNet - -We can create a new ControlNet workflow by modifying the pipeline blocks, reusing components as much as possible, and see how it affects the generation. - -We want to use a different ControlNet from the one that's already loaded. - -```py -from diffusers import ComponentSpec, ControlNetModel -control_blocks = ALL_BLOCKS["controlnet"]["denoise"]() -# update the t2i_blocks and create pipeline -t2i_blocks.sub_blocks["denoise"] = control_blocks -t2i_control_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) - -# fetch the controlnet_pose seperately since we need to change name when adding it to the pipeline -controlnet_spec = ComponentSpec(name="controlnet_pose", type_hint=ControlNetModel, repo="thibaud/controlnet-openpose-sdxl-1.0") -controlnet = controlnet_spec.load(torch_dtype=torch.float16) -t2i_control_pipe.update_components(controlnet=controlnet) - -# fetch the rest of the components from the components manager -comps = components.get_components_by_names(t2i_control_pipe.loader.null_component_names) -t2i_control_pipe.update_components(**comps) - -control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png") -generator = torch.Generator(device="cuda").manual_seed(0) -latents_control = t2i_control_pipe(**text_embeddings, **ipa_embeddings, control_image=control_image, num_inference_steps=25, generator=generator, output="latents") - -image = decoder_node(latents=latents_control, output="images")[0] -image.save("modular_part2_lora_ipa_control.png") -``` - - -Now set up refiner workflow. For refiner blocks, we removed `image_encoder` since the refiner works with latents directly, and `decoder` since we already have a dedicated one. We keep `text_encoder` because SDXL refiner encodes text prompts differently from the text-to-image pipeline, so we cannot share it. - -```py -# Create a refiner blocks -# - removing image_encoder a since we'll use latents from t2i -# - removing decode since we already created a seperate decoder_block -refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"]) -refiner_blocks.sub_blocks.pop("image_encoder") -refiner_blocks.sub_blocks.pop("decode") -``` - -### Refiner - -Create refiner pipeline. refiner has a different unet and use only one text_encoder so it is hosted in a different repo. We pass the same components manager to refiner pipeline, along with a unique "refiner" collection. - -```py -refiner_repo = "YiYiXu/modular_refiner" -refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner") -``` - - -We want to reuse components from the t2i pipeline in the refiner as much as possible. First, let's check the loading status of the refiner pipeline to understand what components are needed: - -```py ->>> refiner_pipe -``` - -Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is different from the one we loaded for text-to-image. The remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler` are already available in the t2i collection, we can reuse them instead of loading duplicates. - -```py -refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) - -# verify loaded correctly -refiner_pipe - -# veryfiy registered to components manager under refiner -components -``` - -Now let's reuse the components from the t2i pipeline in the refiner. We use the`|` to select multiple components from components manager at once: - -```py -# Reuse components from t2i pipeline (select everything at once) -reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2") -refiner_pipe.update_components(**reuse_components) -``` - -You'll see warnings indicating that these components already exist in the components manager: - -```out -component 'text_encoder_2' already exists as 'text_encoder_2_238ae9a7-c864-4837-a8a2-f58ed753b2d0' -component 'tokenizer_2' already exists as 'tokenizer_2_b795af3d-f048-4b07-a770-9e8237a2be2d' -component 'scheduler' already exists as 'scheduler_e3435f63-266a-4427-9383-eb812e830fe8' -component 'vae' already exists as 'vae_357eee6a-4a06-46f1-be83-494f7d60ca69' -``` - -These warnings are expected and indicate that the components manager is correctly identifying that these components are already loaded. The system will reuse the existing components rather than creating duplicates. - -Let's check the components manager again to see the updated state. You should see `text_encoder_2`, `vae`, `tokenizer_2`, and `scheduler` now appear under both "t2i" and "refiner" collections. - -Now let's refine! - -```py -# refine the latents from base text-to-image workflow -refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents") -refined_image = decoder_node(latents=refined_latents, output="images")[0] -refined_image.save("modular_part2_t2i_refine_out.png") - -# refine the latents from the text-to-image lora workflow -refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents") -refined_image = decoder_node(latents=refined_latents, output="images")[0] -refined_image.save("modular_part2_lora_refine_out.png") - -# refine the latents from the text-to-image + lora + ip-adapter workflow -refined_latents = refiner_pipe(image_latents=latents_ipa, prompt=prompt, num_inference_steps=10, output="latents") -refined_image = decoder_node(latents=refined_latents, output="images")[0] -refined_image.save("modular_part2_ipa_refine_out.png") - -# refine the latents from the text-to-image + lora + ip-adapter + controlnet workflow -refined_latents = refiner_pipe(image_latents=latents_control, prompt=prompt, num_inference_steps=10, output="latents") -refined_image = decoder_node(latents=refined_latents, output="images")[0] -refined_image.save("modular_part2_control_refine_out.png") -``` - - -### Results - -Here are the results from our modular pipeline examples. - -#### Base Text-to-Image Generation -| Base Text-to-Image | Base Text-to-Image (Refined) | -|-------------------|------------------------------| -| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) | - -#### LoRA -| LoRA | LoRA (Refined) | -|-------------------|------------------------------| -| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) | - -#### LoRA + IP-Adapter -| LoRA + IP-Adapter | LoRA + IP-Adapter (Refined) | -|-------------------|------------------------------| -| ![LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa.png) | ![LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_ipa_refine_out.png) | - -#### ControlNet + LoRA + IP-Adapter -| ControlNet + LoRA + IP-Adapter | ControlNet + LoRA + IP-Adapter (Refined) | -|-------------------|------------------------------| -| ![ControlNet + LoRA + IP-Adapter](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control.png) | ![ControlNet + LoRA + IP-Adapter Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_control_refine_out.png) | - From be5e10ae611948b7725e2eb3520d7c0077e8ca10 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 8 Jul 2025 09:46:52 +0530 Subject: [PATCH 158/170] Copied-from implementation of PAG-guider (#11882) * update * fix --- .../guiders/perturbed_attention_guidance.py | 161 ++++++++++++++++-- .../modular_pipelines/modular_pipeline.py | 2 +- 2 files changed, 145 insertions(+), 18 deletions(-) diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 3045f2feaae2..1b2256732ffc 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -12,18 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch from ..configuration_utils import register_to_config -from ..hooks import LayerSkipConfig +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook from ..utils import get_logger -from .skip_layer_guidance import SkipLayerGuidance +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState logger = get_logger(__name__) # pylint: disable=invalid-name -class PerturbedAttentionGuidance(SkipLayerGuidance): +class PerturbedAttentionGuidance(BaseGuidance): """ Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 @@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) - PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters + PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters and implementation details. Args: @@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation # for each model architecture. + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + @register_to_config def __init__( self, @@ -89,6 +99,15 @@ def __init__( start: float = 0.0, stop: float = 1.0, ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = perturbed_guidance_scale + self.skip_layer_guidance_start = perturbed_guidance_start + self.skip_layer_guidance_stop = perturbed_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + if perturbed_guidance_config is None: if perturbed_guidance_layers is None: raise ValueError( @@ -130,15 +149,123 @@ def __init__( config.skip_attention_scores = True config.skip_ff = False - super().__init__( - guidance_scale=guidance_scale, - skip_layer_guidance_scale=perturbed_guidance_scale, - skip_layer_guidance_start=perturbed_guidance_start, - skip_layer_guidance_stop=perturbed_guidance_stop, - skip_layer_guidance_layers=perturbed_guidance_layers, - skip_layer_config=perturbed_guidance_config, - guidance_rescale=guidance_rescale, - use_original_formulation=use_original_formulation, - start=start, - stop=stop, - ) + self.skip_layer_config = perturbed_guidance_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5440e5e5a6ff..57af0f220765 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -335,7 +335,7 @@ def init_pipeline( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, - ): + ) -> "ModularPipeline": """ create a ModularPipeline, optionally accept modular_repo to load from hub. """ From e6ffde2936a43d6375de77ae3e705fea94240fa1 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 7 Jul 2025 18:25:31 -1000 Subject: [PATCH 159/170] Apply suggestions from code review Co-authored-by: Aryan --- docs/source/en/modular_diffusers/getting_started.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 1eb39c04d432..2e0cb2fd08cf 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -692,9 +692,9 @@ ComponentSpec( ### Customizing Guidance Techniques -Guiders are guidance techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by modifying the noise predictions or model behavior to steer the generation process in desired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance` and can be easily integrated into modular pipelines, providing a flexible way to enhance generation quality without modifying the underlying diffusion models. +Guiders are implementations of different [classifier-free guidance](https://huggingface.co/papers/2207.12598) techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by steering the model predictions towards desired directions and away from undesired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance`. They can easily be integrated into modular pipelines and provide a flexible way to enhance generation quality without modifying the underlying diffusion models. -**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. But we offer many more guidance techniques beyond CFG, including **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide even better results for specific use cases. +**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. We also offer many other guidance techniques from the latest research in this area - **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide better results for specific use cases. This section demonstrates how to use guiders using the component updating methods we just learned. Since `BaseGuidance` components are stateless (similar to schedulers), they are typically created with default configurations during pipeline initialization using `default_creation_method='from_config'`. This means they don't require loading specs from the repository - you won't see guider listed in `modular_model_index.json` files. @@ -756,7 +756,7 @@ ClassifierFreeGuidance { #### Switch to a Different Guider Type -Since guiders are `from_config` components (ConfigMixin objects), you can pass guider objects directly to switch between different guidance techniques: +Switching between guidance techniques is as simple as passing a guider object of that technique: ```py from diffusers import LayerSkipConfig, PerturbedAttentionGuidance @@ -776,7 +776,7 @@ ModularPipeline.update_components: adding guider with new type: PerturbedAttenti 💡 **Component Loading Methods**: -- For `from_config` components (like guiders, schedulers): You can pass the object directly OR pass a ComponentSpec directly (which calls `create()` under the hood) +- For `from_config` components (like guiders, schedulers): You can pass an object of required type OR pass a ComponentSpec directly (which calls `create()` under the hood) - For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading @@ -915,7 +915,7 @@ Of course, you can also directly modify the `modular_model_index.json` to add a - **SmoothedEnergyGuidance (SEG)**: Helps with energy distribution smoothing - **AdaptiveProjectedGuidance (APG)**: Adaptive guidance that projects predictions for better quality -Experiment with different techniques and parameters to find what works best for your specific use case! +Experiment with different techniques and parameters to find what works best for your specific use case! Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with modular diffusers! From 5f3ebef0d7ca70510746ef0c759aa6f05725ef2d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 06:29:47 +0200 Subject: [PATCH 160/170] update remove duplicated config for pag, and remove the description of all the guiders --- .../en/modular_diffusers/getting_started.md | 28 ++----------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 2e0cb2fd08cf..b38276fb5e0d 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -805,23 +805,6 @@ PerturbedAttentionGuidance { "perturbed_guidance_scale": 2.5, "perturbed_guidance_start": 0.01, "perturbed_guidance_stop": 0.2, - "skip_layer_config": [ - { - "dropout": 1.0, - "fqn": "mid_block.attentions.0.transformer_blocks", - "indices": [ - 2, - 9 - ], - "skip_attention": false, - "skip_attention_scores": true, - "skip_ff": false - } - ], - "skip_layer_guidance_layers": null, - "skip_layer_guidance_scale": 2.5, - "skip_layer_guidance_start": 0.01, - "skip_layer_guidance_stop": 0.2, "start": 0.0, "stop": 1.0, "use_original_formulation": false @@ -833,7 +816,7 @@ The component spec has also been updated to reflect the new guider type: ```py >>> t2i_pipeline.get_component_spec("guider") -ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['use_original_formulation', 'perturbed_guidance_stop', 'stop', 'guidance_rescale', 'start', 'perturbed_guidance_layers', 'perturbed_guidance_start']), ('skip_layer_guidance_scale', 2.5), ('skip_layer_guidance_start', 0.01), ('skip_layer_guidance_stop', 0.2), ('skip_layer_guidance_layers', None), ('skip_layer_config', [LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)]), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['use_original_formulation', 'perturbed_guidance_stop', 'stop', 'guidance_rescale', 'start', 'perturbed_guidance_layers', 'perturbed_guidance_start']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') ``` However, the "guider" is still not included in the pipeline config and will not be saved into the `modular_model_index.json` since it remains a `from_config` component: @@ -908,14 +891,9 @@ Of course, you can also directly modify the `modular_model_index.json` to add a -💡 **Guidance Techniques Summary**: -- **ClassifierFreeGuidance (CFG)**: The standard choice, best for general use and prompt adherence -- **PerturbedAttentionGuidance (PAG)**: Enhances attention-based features by perturbing attention mechanisms -- **SkipLayerGuidance (SLG)**: Improves structure and anatomy coherence by skipping specific layers -- **SmoothedEnergyGuidance (SEG)**: Helps with energy distribution smoothing -- **AdaptiveProjectedGuidance (APG)**: Adaptive guidance that projects predictions for better quality +Experiment with different techniques and parameters to find what works best for your specific use case! You can find all the guider class we support [here](TODO: API doc) -Experiment with different techniques and parameters to find what works best for your specific use case! Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with modular diffusers! +Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with modular diffusers! From 59abd9514ba57eda0c4018383a5775c660373311 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 06:47:14 +0200 Subject: [PATCH 161/170] add link to components manager doc --- docs/source/en/modular_diffusers/getting_started.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index b38276fb5e0d..920d8f8a3572 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -377,7 +377,7 @@ This helps you to: 2. Easily reuse components across different pipelines 3. Apply offloading strategies across multiple pipelines -You can read more about Components Manager [here](TODO) +You can read more about [Components Manager](./components_manager.md) From f95c320467450f8e7867b621509ec1b3f64bdf1b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 07:11:57 +0200 Subject: [PATCH 162/170] addreess more review comments --- .../source/en/modular_diffusers/getting_started.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index 920d8f8a3572..d6b9a61a58c6 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -21,9 +21,9 @@ With Modular Diffusers, we introduce a unified pipeline system that simplifies h In this guide, we will focus on how to build end-to-end pipelines using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](./write_own_pipeline_block.md). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](./end_to_end.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. Let's get started! The Modular Diffusers Framework consists of three main components: -- ModularPipelineBlocks -- PipelineState & BlockState -- ModularPipeline +- ModularPipelineBlocks: Building blocks for your workflow, each block defines inputs/outputs and computation steps. These are just definitions and not runnable. +- PipelineState & BlockState: Store and manage data as it flows through the pipeline. +- ModularPipeline: Loads models and runs the computation steps. You convert blocks to pipelines to make them executable. ## ModularPipelineBlocks @@ -68,7 +68,9 @@ StableDiffusionXLTextEncoderStep( ) ``` -More commonly, you can create a `SequentialPipelineBlocks` using a block classes preset from 🧨 Diffusers. +More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. + +`TEXT2IMAGE_BLOCKS` is a predefined dictionary containing all the blocks needed for a complete text-to-image pipeline (text encoding, denoising, decoding, etc.). We will see more details soon. ```py from diffusers.modular_pipelines import SequentialPipelineBlocks @@ -171,7 +173,7 @@ Note that both the block classes preset and the `sub_blocks` attribute are `Inse **Add a block:** ```py -# Add a block class to the preset +# BLOCKS is a block class preset, you need to add class to it BLOCKS.insert("block_name", BlockClass, index) # Add a block instance to the `sub_blocks` attribute t2i_blocks.sub_blocks.insert("block_name", block_instance, index) @@ -363,6 +365,8 @@ modular_repo_id = "YiYiXu/modular-loader-t2i-0704" t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) ``` +The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet. + 💡 We recommend using `ModularPipeline` with Component Manager by passing a `components_manager`: From cb9dca552398c102d544dfde09ec85e4aafff32a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 20:23:21 +0200 Subject: [PATCH 163/170] add experimental marks to all modular docs --- docs/source/en/modular_diffusers/components_manager.md | 6 ++++++ docs/source/en/modular_diffusers/end_to_end_guide.md | 8 +++++++- docs/source/en/modular_diffusers/getting_started.md | 6 ++++++ .../en/modular_diffusers/write_own_pipeline_block.md | 6 ++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md index 84ed8e7d26fb..316119409ac9 100644 --- a/docs/source/en/modular_diffusers/components_manager.md +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -12,6 +12,12 @@ specific language governing permissions and limitations under the License. # Components Manager + + +🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows. diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index 132c4870b770..42852c6e6420 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -12,6 +12,12 @@ specific language governing permissions and limitations under the License. # End-to-End Developer Guide: Building with Modular Diffusers + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node. @@ -164,7 +170,7 @@ We will use this example script: >>> prompt = "a green pear" >>> negative_prompt = "blurry" >>> ->>> image = dd_pipeline.run( +>>> image = dd_pipeline( ... prompt=prompt, ... negative_prompt=negative_prompt, ... num_inference_steps=25, diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index d6b9a61a58c6..4b82d9c85fb6 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -12,6 +12,12 @@ specific language governing permissions and limitations under the License. # Getting Started with Modular Diffusers: A Comprehensive Overview + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you: **Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index f65af4463ff9..ae2d819e7f61 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -12,6 +12,12 @@ specific language governing permissions and limitations under the License. # Writing Your Own Pipeline Blocks + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. From d27b65411eeb1214d9b55c9e0a39888daa0856c2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 20:23:44 +0200 Subject: [PATCH 164/170] add more docstrings + experimental marks --- .../modular_pipelines/components_manager.py | 246 +++++++++----- .../modular_pipelines/modular_pipeline.py | 301 +++++++++++++----- .../modular_pipeline_utils.py | 2 +- src/diffusers/modular_pipelines/node_utils.py | 17 +- .../stable_diffusion_xl/modular_pipeline.py | 10 + 5 files changed, 416 insertions(+), 160 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 1ccc404d1a21..828f53c393bf 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -142,7 +142,7 @@ def custom_offload_with_hook( user_hook.attach() return user_hook - +# this is the class that user can customize to implement their own offload strategy class AutoOffloadStrategy: """ Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on @@ -213,7 +213,101 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload +# utils for display component info in a readable format +# TODO: move to a different file +def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: + """Summarizes a dictionary by finding common prefixes that share the same value. + + For a dictionary with dot-separated keys like: { + 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], + 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + } + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { + 'down_blocks': [0.6], 'up_blocks': [0.3] + } + """ + # First group by values - convert lists to tuples to make them hashable + value_to_keys = {} + for key, value in d.items(): + value_tuple = tuple(value) if isinstance(value, list) else value + if value_tuple not in value_to_keys: + value_to_keys[value_tuple] = [] + value_to_keys[value_tuple].append(key) + + def find_common_prefix(keys: List[str]) -> str: + """Find the shortest common prefix among a list of dot-separated keys.""" + if not keys: + return "" + if len(keys) == 1: + return keys[0] + + # Split all keys into parts + key_parts = [k.split(".") for k in keys] + + # Find how many initial parts are common + common_length = 0 + for parts in zip(*key_parts): + if len(set(parts)) == 1: # All parts at this position are the same + common_length += 1 + else: + break + + if common_length == 0: + return "" + + # Return the common prefix + return ".".join(key_parts[0][:common_length]) + + # Create summary by finding common prefixes for each value group + summary = {} + for value_tuple, keys in value_to_keys.items(): + prefix = find_common_prefix(keys) + if prefix: # Only add if we found a common prefix + # Convert tuple back to list if it was originally a list + value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple + summary[prefix] = value + else: + summary[""] = value # Use empty string if no common prefix + + return summary + + class ComponentsManager: + """ + A central registry and management system for model components across multiple pipelines. + + [`ComponentsManager`] provides a unified way to register, track, and reuse model components + (like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes + features for duplicate detection, memory management, and component organization. + + + + This is an experimental feature and is likely to change in the future. + + + + Example: + ```python + from diffusers import ComponentsManager + + # Create a components manager + cm = ComponentsManager() + + # Add components + cm.add("unet", unet_model, collection="sdxl") + cm.add("vae", vae_model, collection="sdxl") + + # Enable auto offloading + cm.enable_auto_cpu_offload(device="cuda") + + # Retrieve components + unet = cm.get_one(name="unet", collection="sdxl") + ``` + """ + + _available_info_fields = [ "model_id", "added_time", @@ -278,7 +372,19 @@ def _lookup_ids( def _id_to_name(component_id: str): return "_".join(component_id.split("_")[:-1]) - def add(self, name, component, collection: Optional[str] = None): + def add(self, name: str, component: Any, collection: Optional[str] = None): + """ + Add a component to the ComponentsManager. + + Args: + name (str): The name of the component + component (Any): The component to add + collection (Optional[str]): The collection to add the component to + + Returns: + str: The unique component ID, which is generated as "{name}_{id(component)}" where + id(component) is Python's built-in unique identifier for the object + """ component_id = f"{name}_{id(component)}" # check for duplicated components @@ -334,6 +440,12 @@ def add(self, name, component, collection: Optional[str] = None): return component_id def remove(self, component_id: str = None): + """ + Remove a component from the ComponentsManager. + + Args: + component_id (str): The ID of the component to remove + """ if component_id not in self.components: logger.warning(f"Component '{component_id}' not found in ComponentsManager") return @@ -545,6 +657,22 @@ def matches_pattern(component_id, pattern, exact_match=False): return get_return_dict(matches, return_dict_with_names) def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): + """ + Enable automatic CPU offloading for all components. + + The algorithm works as follows: + 1. All models start on CPU by default + 2. When a model's forward pass is called, it's moved to the execution device + 3. If there's insufficient memory, other models on the device are moved back to CPU + 4. The system tries to offload the smallest combination of models that frees enough memory + 5. Models stay on the execution device until another model needs memory and forces them off + + Args: + device (Union[str, int, torch.device]): The execution device where models are moved for forward passes + memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of + memory to keep free on the device to avoid running out of memory during + model execution (e.g., for intermediate activations, gradients, etc.) + """ if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") @@ -574,6 +702,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda" self._auto_offload_device = device def disable_auto_cpu_offload(self): + """ + Disable automatic CPU offloading for all components. + """ if self.model_hooks is None: self._auto_offload_enabled = False return @@ -595,13 +726,12 @@ def get_model_info( """Get comprehensive information about a component. Args: - component_id: Name of the component to get info for - fields: Optional field(s) to return. Can be a string for single field or list of fields. + component_id (str): Name of the component to get info for + fields (Optional[Union[str, List[str]]]): Field(s) to return. Can be a string for single field or list of fields. If None, uses the available_info_fields setting. Returns: - Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a - single field is requested as string, returns just that field's value. + Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields. """ if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") @@ -808,15 +938,16 @@ def get_one( load_id: Optional[str] = None, ) -> Any: """ - Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in - a component_id Raises an error if multiple components match or none are found. support pattern matching for - name + Get a single component by either: + - searching name (pattern matching), collection, or load_id. + - passing in a component_id + Raises an error if multiple components match or none are found. Args: - component_id: Optional component ID to get - name: Component name or pattern - collection: Optional collection to filter by - load_id: Optional load_id to filter by + component_id (Optional[str]): Optional component ID to get + name (Optional[str]): Component name or pattern + collection (Optional[str]): Optional collection to filter by + load_id (Optional[str]): Optional load_id to filter by Returns: A single component @@ -847,6 +978,13 @@ def get_one( def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): """ Get component IDs by a list of names, optionally filtered by collection. + + Args: + names (Union[str, List[str]]): List of component names + collection (Optional[str]): Optional collection to filter by + + Returns: + List[str]: List of component IDs """ ids = set() if not isinstance(names, list): @@ -858,6 +996,20 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): """ Get components by a list of IDs. + + Args: + ids (List[str]): + List of component IDs + return_dict_with_names (Optional[bool]): + Whether to return a dictionary with component names as keys: + + Returns: + Dict[str, Any]: Dictionary of components. + - If return_dict_with_names=True, keys are component names. + - If return_dict_with_names=False, keys are component IDs. + + Raises: + ValueError: If duplicate component names are found in the search results when return_dict_with_names=True """ components = {id: self.components[id] for id in ids} @@ -877,65 +1029,17 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional def get_components_by_names(self, names: List[str], collection: Optional[str] = None): """ Get components by a list of names, optionally filtered by collection. - """ - ids = self.get_ids(names, collection) - return self.get_components_by_ids(ids) - - -def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: - """Summarizes a dictionary by finding common prefixes that share the same value. - - For a dictionary with dot-separated keys like: { - 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], - 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], - 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], - } - - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { - 'down_blocks': [0.6], 'up_blocks': [0.3] - } - """ - # First group by values - convert lists to tuples to make them hashable - value_to_keys = {} - for key, value in d.items(): - value_tuple = tuple(value) if isinstance(value, list) else value - if value_tuple not in value_to_keys: - value_to_keys[value_tuple] = [] - value_to_keys[value_tuple].append(key) - def find_common_prefix(keys: List[str]) -> str: - """Find the shortest common prefix among a list of dot-separated keys.""" - if not keys: - return "" - if len(keys) == 1: - return keys[0] - - # Split all keys into parts - key_parts = [k.split(".") for k in keys] - - # Find how many initial parts are common - common_length = 0 - for parts in zip(*key_parts): - if len(set(parts)) == 1: # All parts at this position are the same - common_length += 1 - else: - break - - if common_length == 0: - return "" + Args: + names (List[str]): List of component names + collection (Optional[str]): Optional collection to filter by - # Return the common prefix - return ".".join(key_parts[0][:common_length]) + Returns: + Dict[str, Any]: Dictionary of components with component names as keys - # Create summary by finding common prefixes for each value group - summary = {} - for value_tuple, keys in value_to_keys.items(): - prefix = find_common_prefix(keys) - if prefix: # Only add if we found a common prefix - # Convert tuple back to list if it was originally a list - value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple - summary[prefix] = value - else: - summary[""] = value # Use empty string if no common prefix + Raises: + ValueError: If duplicate component names are found in the search results + """ + ids = self.get_ids(names, collection) + return self.get_components_by_ids(ids) - return summary diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 57af0f220765..fbdd01af831d 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -83,12 +83,17 @@ class PipelineState: def set_input(self, key: str, value: Any, kwargs_type: str = None): """ - Add an input to the pipeline state with optional metadata. + Add an input to the immutable pipeline state, i.e, pipeline_state.inputs. + + The kwargs_type parameter allows you to associate inputs with specific input types. + For example, if you call set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), + this input will be automatically fetched when a pipeline block has "guider_kwargs" + in its expected_inputs list. Args: key (str): The key for the input value (Any): The input value - kwargs_type (str): The kwargs_type to store with the input + kwargs_type (str): The kwargs_type with which the input is associated """ self.inputs[key] = value if kwargs_type is not None: @@ -99,12 +104,17 @@ def set_input(self, key: str, value: Any, kwargs_type: str = None): def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): """ - Add an intermediate value to the pipeline state with optional metadata. + Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates. + + The kwargs_type parameter allows you to associate intermediate values with specific input types. + For example, if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), + this intermediate value will be automatically fetched when a pipeline block has "latents_kwargs" + in its expected_intermediate_inputs list. Args: key (str): The key for the intermediate value value (Any): The intermediate value - kwargs_type (str): The kwargs_type to store with the intermediate value + kwargs_type (str): The kwargs_type with which the intermediate value is associated """ self.intermediates[key] = value if kwargs_type is not None: @@ -114,11 +124,31 @@ def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: + """ + Get an input from the pipeline state. + + Args: + key (str): The key for the input + default (Any): The default value to return if the input is not found + + Returns: + Any: The input value + """ value = self.inputs.get(key, default) if value is not None: return deepcopy(value) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + """ + Get multiple inputs from the pipeline state. + + Args: + keys (List[str]): The keys for the inputs + default (Any): The default value to return if the input is not found + + Returns: + Dict[str, Any]: Dictionary of inputs with matching keys + """ return {key: self.inputs.get(key, default) for key in keys} def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: @@ -148,12 +178,38 @@ def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]: return self.get_intermediates(intermediate_names) def get_intermediate(self, key: str, default: Any = None) -> Any: + """ + Get an intermediate value from the pipeline state. + + Args: + key (str): The key for the intermediate value + default (Any): The default value to return if the intermediate value is not found + + Returns: + Any: The intermediate value + """ return self.intermediates.get(key, default) def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + """ + Get multiple intermediate values from the pipeline state. + + Args: + keys (List[str]): The keys for the intermediate values + default (Any): The default value to return if the intermediate value is not found + + Returns: + Dict[str, Any]: Dictionary of intermediate values with matching keys + """ return {key: self.intermediates.get(key, default) for key in keys} def to_dict(self) -> Dict[str, Any]: + """ + Convert PipelineState to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing all attributes of the PipelineState + """ return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} def __repr__(self): @@ -258,6 +314,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks + + [`ModularPipelineBlocks`] provides method to load and save the defination of pipeline blocks. + + + + This is an experimental feature and is likely to change in the future. + + """ config_name = "config.json" @@ -350,9 +414,102 @@ def init_pipeline( collection=collection, ) return modular_pipeline + + @staticmethod + def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current + default value is None and new default value is not None. Warns if multiple non-None default values exist for the + same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): + warnings.warn( + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + else: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + + return list(combined_dict.values()) + + @staticmethod + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first + occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if (output_param.name not in combined_dict) or ( + combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None + ): + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) + + class PipelineBlock(ModularPipelineBlocks): + """ + A Pipeline Block is the basic building block of a Modular Pipeline. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Args: + description (str, optional): A description of the block, defaults to None. Define as a property in subclasses. + expected_components (List[ComponentSpec], optional): A list of components that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + expected_configs (List[ConfigSpec], optional): A list of configs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + inputs (List[InputParam], optional): A list of inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + intermediate_inputs (List[InputParam], optional): A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + intermediate_outputs (List[OutputParam], optional): A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + outputs (List[OutputParam], optional): A list of outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + required_inputs (List[str], optional): A list of required inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + required_intermediate_inputs (List[str], optional): A list of required intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + required_intermediate_outputs (List[str], optional): A list of required intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + """ + model_name = None def __init__(self): @@ -548,75 +705,18 @@ def set_block_state(self, state: PipelineState, block_state: BlockState): state.set_intermediate(param_name, param, input_param.kwargs_type) -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current - default value is None and new default value is not None. Warns if multiple non-None default values exist for the - same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name is None and input_param.kwargs_type is not None: - input_name = "*_" + input_param.kwargs_type - else: - input_name = input_param.name - if input_name in combined_dict: - current_param = combined_dict[input_name] - if ( - current_param.default is not None - and input_param.default is not None - and current_param.default != input_param.default - ): - warnings.warn( - f"Multiple different default values found for input '{input_name}': " - f"{current_param.default} (from block '{value_sources[input_name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_name] = input_param - value_sources[input_name] = block_name - else: - combined_dict[input_name] = input_param - value_sources[input_name] = block_name - - return list(combined_dict.values()) - - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first - occurrence of each output name. - - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects +class AutoPipelineBlocks(ModularPipelineBlocks): """ - combined_dict = {} # name -> OutputParam + A Pipeline Blocks that automatically selects a block to run based on the inputs. - for block_name, outputs in named_output_lists: - for output_param in outputs: - if (output_param.name not in combined_dict) or ( - combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None - ): - combined_dict[output_param.name] = output_param + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) - return list(combined_dict.values()) + + This is an experimental feature and is likely to change in the future. -class AutoPipelineBlocks(ModularPipelineBlocks): - """ - A class that automatically selects a block to run based on the inputs. + Attributes: block_classes: List of block classes to be used @@ -713,7 +813,7 @@ def required_intermediate_inputs(self) -> List[str]: @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] - combined_inputs = combine_inputs(*named_inputs) + combined_inputs = self.combine_inputs(*named_inputs) # mark Required inputs only if that input is required by all the blocks for input_param in combined_inputs: if input_param.name in self.required_inputs: @@ -725,7 +825,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediate_inputs(self) -> List[str]: named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()] - combined_inputs = combine_inputs(*named_inputs) + combined_inputs = self.combine_inputs(*named_inputs) # mark Required inputs only if that input is required by all the blocks for input_param in combined_inputs: if input_param.name in self.required_intermediate_inputs: @@ -737,13 +837,13 @@ def intermediate_inputs(self) -> List[str]: @property def intermediate_outputs(self) -> List[str]: named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] - combined_outputs = combine_outputs(*named_outputs) + combined_outputs = self.combine_outputs(*named_outputs) return combined_outputs @property def outputs(self) -> List[str]: named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()] - combined_outputs = combine_outputs(*named_outputs) + combined_outputs = self.combine_outputs(*named_outputs) return combined_outputs @torch.no_grad() @@ -897,7 +997,20 @@ def doc(self): class SequentialPipelineBlocks(ModularPipelineBlocks): """ - A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block """ block_classes = [] @@ -990,7 +1103,7 @@ def inputs(self) -> List[Tuple[str, Any]]: def get_inputs(self): named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] - combined_inputs = combine_inputs(*named_inputs) + combined_inputs = self.combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks for input_param in combined_inputs: if input_param.name in self.required_inputs: @@ -1036,7 +1149,7 @@ def intermediate_outputs(self) -> List[str]: # filter out them here so they do not end up as intermediate_outputs if name not in inp_names: named_outputs.append((name, block.intermediate_outputs)) - combined_outputs = combine_outputs(*named_outputs) + combined_outputs = self.combine_outputs(*named_outputs) return combined_outputs # YiYi TODO: I think we can remove the outputs property @@ -1258,11 +1371,23 @@ def doc(self): ) -# YiYi TODO: __repr__ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ - A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in + A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block """ model_name = None @@ -1343,7 +1468,7 @@ def expected_configs(self): def get_inputs(self): named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] named_inputs.append(("loop", self.loop_inputs)) - combined_inputs = combine_inputs(*named_inputs) + combined_inputs = self.combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks for input_param in combined_inputs: if input_param.name in self.required_inputs: @@ -1423,7 +1548,7 @@ def required_intermediate_inputs(self) -> List[str]: @property def intermediate_outputs(self) -> List[str]: named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] - combined_outputs = combine_outputs(*named_outputs) + combined_outputs = self.combine_outputs(*named_outputs) for output in self.loop_intermediate_outputs: if output.name not in {output.name for output in combined_outputs}: combined_outputs.append(output) @@ -1644,11 +1769,17 @@ def set_progress_bar_config(self, **kwargs): # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? the are basically just key/val kwargs -# 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() +# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components() class ModularPipeline(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines. + + + This is an experimental feature and is likely to change in the future. + + + Args: blocks: ModularPipelineBlocks, the blocks to be used in the pipeline """ @@ -1669,12 +1800,12 @@ def __init__( Initialize a ModularPipeline instance. This method sets up the pipeline by: - 1. creating default pipeline blocks if not provided - 2. gather component and config specifications based on the pipeline blocks's requirement (e.g. + - creating default pipeline blocks if not provided + - gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs) - 3. update the loading specs of from_pretrained components based on the modular_model_index.json file from + - update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided - 4. create defaultfrom_config components and register everything + - create defaultfrom_config components and register everything Args: blocks: `ModularPipelineBlocks` instance. If None, will attempt to load @@ -1866,7 +1997,7 @@ def load_default_components(self, **kwargs): Load from_pretrained components using the loading specs in the config dict. Args: - **kwargs: Additional arguments passed to `load_components()` method + **kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc. """ names = [ name diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index ee1f30d93d9d..4fac5ef4f2d5 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -66,7 +66,7 @@ def __repr__(self): # YiYi TODO: # 1. validate the dataclass fields -# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained() @dataclass class ComponentSpec: """Specification for a pipeline component. diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index f644ddc9edea..93d09c22f439 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -347,6 +347,17 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): class ModularNode(ConfigMixin): + """ + A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. + It is a wrapper around a ModularPipelineBlocks object. + + + + This is an experimental feature and is likely to change in the future. + + + """ + config_name = "node_config.json" @classmethod @@ -496,7 +507,7 @@ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): self.register_to_config(**register_dict) def setup(self, components_manager, collection=None): - self.blocks.setup_loader(components_manager=components_manager, collection=collection) + self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection) self._components_manager = components_manager @property @@ -649,6 +660,6 @@ def process_inputs(self, **kwargs): def execute(self, **kwargs): params_components, params_run, return_output_names = self.process_inputs(**kwargs) - self.blocks.loader.update(**params_components) - output = self.blocks.run(**params_run, output=return_output_names) + self.pipeline.update_components(**params_components) + output = self.pipeline(**params_run, output=return_output_names) return output diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index 90850ea53606..0c45857da742 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -44,6 +44,16 @@ class StableDiffusionXLModularPipeline( StableDiffusionXLLoraLoaderMixin, ModularIPAdapterMixin, ): + """ + A ModularPipeline for Stable Diffusion XL. + + + + This is an experimental feature and is likely to change in the future. + + + """ + @property def default_height(self): return self.default_sample_size * self.vae_scale_factor From 595581d6ba1240a11d77018132c58c4257681df8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 8 Jul 2025 22:13:00 +0200 Subject: [PATCH 165/170] style --- .../modular_pipelines/components_manager.py | 49 +++++++------- .../modular_pipelines/modular_pipeline.py | 67 ++++++++++++------- src/diffusers/modular_pipelines/node_utils.py | 4 +- 3 files changed, 68 insertions(+), 52 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 828f53c393bf..08e6d80fefd2 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -142,6 +142,7 @@ def custom_offload_with_hook( user_hook.attach() return user_hook + # this is the class that user can customize to implement their own offload strategy class AutoOffloadStrategy: """ @@ -277,37 +278,36 @@ def find_common_prefix(keys: List[str]) -> str: class ComponentsManager: """ A central registry and management system for model components across multiple pipelines. - - [`ComponentsManager`] provides a unified way to register, track, and reuse model components - (like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes - features for duplicate detection, memory management, and component organization. - + + [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text + encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory + management, and component organization. + This is an experimental feature and is likely to change in the future. - + Example: ```python from diffusers import ComponentsManager - + # Create a components manager cm = ComponentsManager() - + # Add components cm.add("unet", unet_model, collection="sdxl") cm.add("vae", vae_model, collection="sdxl") - + # Enable auto offloading cm.enable_auto_cpu_offload(device="cuda") - + # Retrieve components unet = cm.get_one(name="unet", collection="sdxl") ``` """ - _available_info_fields = [ "model_id", "added_time", @@ -382,7 +382,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): collection (Optional[str]): The collection to add the component to Returns: - str: The unique component ID, which is generated as "{name}_{id(component)}" where + str: The unique component ID, which is generated as "{name}_{id(component)}" where id(component) is Python's built-in unique identifier for the object """ component_id = f"{name}_{id(component)}" @@ -669,9 +669,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda" Args: device (Union[str, int, torch.device]): The execution device where models are moved for forward passes - memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of - memory to keep free on the device to avoid running out of memory during - model execution (e.g., for intermediate activations, gradients, etc.) + memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of + memory to keep free on the device to avoid running out of memory during model + execution (e.g., for intermediate activations, gradients, etc.) """ if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") @@ -727,11 +727,13 @@ def get_model_info( Args: component_id (str): Name of the component to get info for - fields (Optional[Union[str, List[str]]]): Field(s) to return. Can be a string for single field or list of fields. - If None, uses the available_info_fields setting. + fields (Optional[Union[str, List[str]]]): + Field(s) to return. Can be a string for single field or list of fields. If None, uses the + available_info_fields setting. Returns: - Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields. + Dictionary containing requested component metadata. If fields is specified, returns only those fields. + Otherwise, returns all fields. """ if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") @@ -938,10 +940,10 @@ def get_one( load_id: Optional[str] = None, ) -> Any: """ - Get a single component by either: - - searching name (pattern matching), collection, or load_id. + Get a single component by either: + - searching name (pattern matching), collection, or load_id. - passing in a component_id - Raises an error if multiple components match or none are found. + Raises an error if multiple components match or none are found. Args: component_id (Optional[str]): Optional component ID to get @@ -998,13 +1000,13 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional Get components by a list of IDs. Args: - ids (List[str]): + ids (List[str]): List of component IDs return_dict_with_names (Optional[bool]): Whether to return a dictionary with component names as keys: Returns: - Dict[str, Any]: Dictionary of components. + Dict[str, Any]: Dictionary of components. - If return_dict_with_names=True, keys are component names. - If return_dict_with_names=False, keys are component IDs. @@ -1042,4 +1044,3 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] = """ ids = self.get_ids(names, collection) return self.get_components_by_ids(ids) - diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index fbdd01af831d..b99478cb58d1 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -85,10 +85,9 @@ def set_input(self, key: str, value: Any, kwargs_type: str = None): """ Add an input to the immutable pipeline state, i.e, pipeline_state.inputs. - The kwargs_type parameter allows you to associate inputs with specific input types. - For example, if you call set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), - this input will be automatically fetched when a pipeline block has "guider_kwargs" - in its expected_inputs list. + The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call + set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a + pipeline block has "guider_kwargs" in its expected_inputs list. Args: key (str): The key for the input @@ -106,10 +105,9 @@ def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): """ Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates. - The kwargs_type parameter allows you to associate intermediate values with specific input types. - For example, if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), - this intermediate value will be automatically fetched when a pipeline block has "latents_kwargs" - in its expected_intermediate_inputs list. + The kwargs_type parameter allows you to associate intermediate values with specific input types. For example, + if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be + automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list. Args: key (str): The key for the intermediate value @@ -414,13 +412,13 @@ def init_pipeline( collection=collection, ) return modular_pipeline - + @staticmethod def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current - default value is None and new default value is not None. Warns if multiple non-None default values exist for the - same input. + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. Args: named_input_lists: List of tuples containing (block_name, input_param_list) pairs @@ -482,8 +480,6 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) - - class PipelineBlock(ModularPipelineBlocks): """ A Pipeline Block is the basic building block of a Modular Pipeline. @@ -499,15 +495,33 @@ class PipelineBlock(ModularPipelineBlocks): Args: description (str, optional): A description of the block, defaults to None. Define as a property in subclasses. - expected_components (List[ComponentSpec], optional): A list of components that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - expected_configs (List[ConfigSpec], optional): A list of configs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - inputs (List[InputParam], optional): A list of inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - intermediate_inputs (List[InputParam], optional): A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - intermediate_outputs (List[OutputParam], optional): A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - outputs (List[OutputParam], optional): A list of outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - required_inputs (List[str], optional): A list of required inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - required_intermediate_inputs (List[str], optional): A list of required intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. - required_intermediate_outputs (List[str], optional): A list of required intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses. + expected_components (List[ComponentSpec], optional): + A list of components that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + expected_configs (List[ConfigSpec], optional): + A list of configs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + inputs (List[InputParam], optional): + A list of inputs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + intermediate_inputs (List[InputParam], optional): + A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, + define as a property in subclasses. + intermediate_outputs (List[OutputParam], optional): + A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, + define as a property in subclasses. + outputs (List[OutputParam], optional): + A list of outputs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + required_inputs (List[str], optional): + A list of required inputs that are expected to be used in the block, defaults to []. To override, define as + a property in subclasses. + required_intermediate_inputs (List[str], optional): + A list of required intermediate inputs that are expected to be used in the block, defaults to []. To + override, define as a property in subclasses. + required_intermediate_outputs (List[str], optional): + A list of required intermediate outputs that are expected to be used in the block, defaults to []. To + override, define as a property in subclasses. """ model_name = None @@ -997,7 +1011,8 @@ def doc(self): class SequentialPipelineBlocks(ModularPipelineBlocks): """ - A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in + sequence. This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the library implements for all the pipeline blocks (such as loading or saving etc.) @@ -1373,8 +1388,8 @@ def doc(self): class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ - A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each block in - sequence. + A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each + block in sequence. This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the library implements for all the pipeline blocks (such as loading or saving etc.) diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 93d09c22f439..fb9a03c755ac 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -348,8 +348,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): class ModularNode(ConfigMixin): """ - A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. - It is a wrapper around a ModularPipelineBlocks object. + A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper + around a ModularPipelineBlocks object. From de7cdf6287baf0189afd5cc3dd3ea4e7afd8f75c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 9 Jul 2025 10:00:27 +0530 Subject: [PATCH 166/170] Merge modular diffusers with main (#11893) * [CI] Fix big GPU test marker (#11786) * update * update * First Block Cache (#11180) * update * modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) * remove debug logs * update * cache context for different batches of data * fix hs residual bug for single return outputs; support ltx * fix controlnet flux * support flux, ltx i2v, ltx condition * update * update * Update docs/source/en/api/cache.md * Update src/diffusers/hooks/hooks.py Co-authored-by: Dhruv Nair * address review comments pt. 1 * address review comments pt. 2 * cache context refacotr; address review pt. 3 * address review comments * metadata registration with decorators instead of centralized * support cogvideox * support mochi * fix * remove unused function * remove central registry based on review * update --------- Co-authored-by: Dhruv Nair * fix --------- Co-authored-by: Dhruv Nair --- .github/workflows/nightly_tests.yml | 2 +- docs/source/en/api/cache.md | 6 + src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 15 ++ src/diffusers/hooks/_helpers.py | 141 ++++++----- src/diffusers/hooks/first_block_cache.py | 227 ++++++++++++++++++ src/diffusers/hooks/hooks.py | 57 ++++- src/diffusers/hooks/layer_skip.py | 9 +- src/diffusers/models/cache_utils.py | 40 ++- .../models/controlnets/controlnet_flux.py | 10 +- .../transformers/transformer_cogview4.py | 2 + .../models/transformers/transformer_flux.py | 21 +- .../models/transformers/transformer_wan.py | 2 + .../pipelines/cogvideo/pipeline_cogvideox.py | 17 +- .../pipeline_cogvideox_fun_control.py | 17 +- .../pipeline_cogvideox_image2video.py | 19 +- .../pipeline_cogvideox_video2video.py | 17 +- .../pipelines/cogview4/pipeline_cogview4.py | 31 +-- src/diffusers/pipelines/flux/pipeline_flux.py | 41 ++-- .../hunyuan_video/pipeline_hunyuan_video.py | 34 +-- src/diffusers/pipelines/ltx/pipeline_ltx.py | 25 +- .../pipelines/ltx/pipeline_ltx_condition.py | 19 +- .../pipelines/ltx/pipeline_ltx_image2video.py | 25 +- .../pipelines/mochi/pipeline_mochi.py | 17 +- src/diffusers/pipelines/wan/pipeline_wan.py | 24 +- src/diffusers/utils/dummy_pt_objects.py | 19 ++ src/diffusers/utils/testing_utils.py | 4 + tests/conftest.py | 4 + tests/lora/test_lora_layers_flux.py | 3 - tests/lora/test_lora_layers_hunyuanvideo.py | 2 - tests/lora/test_lora_layers_sd3.py | 2 - tests/pipelines/cogvideo/test_cogvideox.py | 7 +- .../controlnet_flux/test_controlnet_flux.py | 2 - .../controlnet_sd3/test_controlnet_sd3.py | 2 - tests/pipelines/flux/test_pipeline_flux.py | 7 +- .../flux/test_pipeline_flux_redux.py | 2 - .../hunyuan_video/test_hunyuan_video.py | 7 +- tests/pipelines/ltx/test_ltx.py | 8 +- tests/pipelines/mochi/test_mochi.py | 8 +- .../test_pipeline_stable_diffusion_3.py | 2 - ...est_pipeline_stable_diffusion_3_img2img.py | 2 - tests/pipelines/test_pipelines_common.py | 52 +++- 42 files changed, 687 insertions(+), 268 deletions(-) create mode 100644 src/diffusers/hooks/first_block_cache.py diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 16e1a70b84fe..384f07506afe 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -248,7 +248,7 @@ jobs: BIG_GPU_MEMORY: 40 run: | python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -m "big_gpu_with_torch_cuda" \ + -m "big_accelerator" \ --make-reports=tests_big_gpu_torch_cuda \ --report-log=tests_big_gpu_torch_cuda.log \ tests/ diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index e90cb32c54e5..9ba474208551 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate [[autodoc]] FasterCacheConfig [[autodoc]] apply_faster_cache + +### FirstBlockCacheConfig + +[[autodoc]] FirstBlockCacheConfig + +[[autodoc]] apply_first_block_cache diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 77971de41402..ab80ddffec50 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -147,11 +147,13 @@ _import_structure["hooks"].extend( [ "FasterCacheConfig", + "FirstBlockCacheConfig", "HookRegistry", "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "apply_faster_cache", + "apply_first_block_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", ] @@ -793,11 +795,13 @@ ) from .hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, HookRegistry, LayerSkipConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, apply_faster_cache, + apply_first_block_cache, apply_layer_skip, apply_pyramid_attention_broadcast, ) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 9d0e96e9e79e..525a0747da8b 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -1,8 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ..utils import is_torch_available if is_torch_available(): from .faster_cache import FasterCacheConfig, apply_faster_cache + from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 1ef0cbf15551..960d14e6fa2a 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -12,23 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from dataclasses import dataclass -from typing import Any, Callable, Type - -from ..models.attention import BasicTransformerBlock -from ..models.attention_processor import AttnProcessor2_0 -from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock -from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock -from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock -from ..models.transformers.transformer_hunyuan_video import ( - HunyuanVideoSingleTransformerBlock, - HunyuanVideoTokenReplaceSingleTransformerBlock, - HunyuanVideoTokenReplaceTransformerBlock, - HunyuanVideoTransformerBlock, -) -from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock -from ..models.transformers.transformer_mochi import MochiTransformerBlock -from ..models.transformers.transformer_wan import WanTransformerBlock +from typing import Any, Callable, Dict, Type @dataclass @@ -38,40 +24,90 @@ class AttentionProcessorMetadata: @dataclass class TransformerBlockMetadata: - skip_block_output_fn: Callable[[Any], Any] return_hidden_states_index: int = None return_encoder_hidden_states_index: int = None + _cls: Type = None + _cached_parameter_indices: Dict[str, int] = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + if identifier in kwargs: + return kwargs[identifier] + if self._cached_parameter_indices is not None: + return args[self._cached_parameter_indices[identifier]] + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # skip `self` + self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + if identifier not in self._cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + index = self._cached_parameter_indices[identifier] + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + return args[index] + class AttentionProcessorRegistry: _registry = {} + # TODO(aryan): this is only required for the time being because we need to do the registrations + # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular + # import errors because of the models imported in this file. + _is_registered = False @classmethod def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._register() cls._registry[model_class] = metadata @classmethod def get(cls, model_class: Type) -> AttentionProcessorMetadata: + cls._register() if model_class not in cls._registry: raise ValueError(f"Model class {model_class} not registered.") return cls._registry[model_class] + @classmethod + def _register(cls): + if cls._is_registered: + return + cls._is_registered = True + _register_attention_processors_metadata() + class TransformerBlockRegistry: _registry = {} + # TODO(aryan): this is only required for the time being because we need to do the registrations + # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular + # import errors because of the models imported in this file. + _is_registered = False @classmethod def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._register() + metadata._cls = model_class cls._registry[model_class] = metadata @classmethod def get(cls, model_class: Type) -> TransformerBlockMetadata: + cls._register() if model_class not in cls._registry: raise ValueError(f"Model class {model_class} not registered.") return cls._registry[model_class] + @classmethod + def _register(cls): + if cls._is_registered: + return + cls._is_registered = True + _register_transformer_blocks_metadata() + def _register_attention_processors_metadata(): + from ..models.attention_processor import AttnProcessor2_0 + from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor + # AttnProcessor2_0 AttentionProcessorRegistry.register( model_class=AttnProcessor2_0, @@ -90,11 +126,24 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): + from ..models.attention import BasicTransformerBlock + from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock + from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock + from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, + ) + from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock + from ..models.transformers.transformer_mochi import MochiTransformerBlock + from ..models.transformers.transformer_wan import WanTransformerBlock + # BasicTransformerBlock TransformerBlockRegistry.register( model_class=BasicTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=None, ), @@ -104,7 +153,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=CogVideoXBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -114,7 +162,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=CogView4TransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -124,7 +171,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=FluxTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, return_hidden_states_index=1, return_encoder_hidden_states_index=0, ), @@ -132,7 +178,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=FluxSingleTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, return_hidden_states_index=1, return_encoder_hidden_states_index=0, ), @@ -142,7 +187,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=HunyuanVideoTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -150,7 +194,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=HunyuanVideoSingleTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -158,7 +201,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=HunyuanVideoTokenReplaceTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -166,7 +208,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -176,7 +217,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=LTXVideoTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=None, ), @@ -186,7 +226,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=MochiTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=1, ), @@ -196,7 +235,6 @@ def _register_transformer_blocks_metadata(): TransformerBlockRegistry.register( model_class=WanTransformerBlock, metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, return_hidden_states_index=0, return_encoder_hidden_states_index=None, ), @@ -223,49 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states - - -def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - return hidden_states - - -def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - encoder_hidden_states = kwargs.get("encoder_hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - if encoder_hidden_states is None and len(args) > 1: - encoder_hidden_states = args[1] - return hidden_states, encoder_hidden_states - - -def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - encoder_hidden_states = kwargs.get("encoder_hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - if encoder_hidden_states is None and len(args) > 1: - encoder_hidden_states = args[1] - return encoder_hidden_states, hidden_states - - -_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states -_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states -_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states # fmt: on - - -_register_attention_processors_metadata() -_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py new file mode 100644 index 000000000000..40ae8c5a263a --- /dev/null +++ b/src/diffusers/hooks/first_block_cache.py @@ -0,0 +1,227 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" +_FBC_BLOCK_HOOK = "fbc_block_hook" + + +@dataclass +class FirstBlockCacheConfig: + r""" + Configuration for [First Block + Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). + + Args: + threshold (`float`, defaults to `0.05`): + The threshold to determine whether or not a forward pass through all layers of the model is required. A + higher threshold usually results in a forward pass through a lower number of layers and faster inference, + but might lead to poorer generation quality. A lower threshold may not result in significant generation + speedup. The threshold is compared against the absmean difference of the residuals between the current and + cached outputs from the first transformer block. If the difference is below the threshold, the forward pass + is skipped. + """ + + threshold: float = 0.05 + + +class FBCSharedBlockState(BaseState): + def __init__(self) -> None: + super().__init__() + + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, threshold: float): + self.state_manager = state_manager + self.threshold = threshold + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + + output = self.fn_ref.original_forward(*args, **kwargs) + is_output_tuple = isinstance(output, tuple) + + if is_output_tuple: + hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states + else: + hidden_states_residual = output - original_hidden_states + + shared_state: FBCSharedBlockState = self.state_manager.get_state() + hidden_states = encoder_hidden_states = None + should_compute = self._should_compute_remaining_blocks(hidden_states_residual) + shared_state.should_compute = should_compute + + if not should_compute: + # Apply caching + if is_output_tuple: + hidden_states = ( + shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + ) + else: + hidden_states = shared_state.tail_block_residuals[0] + output + + if self._metadata.return_encoder_hidden_states_index is not None: + assert is_output_tuple + encoder_hidden_states = ( + shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] + ) + + if is_output_tuple: + return_output = [None] * len(output) + return_output[self._metadata.return_hidden_states_index] = hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return_output = tuple(return_output) + else: + return_output = hidden_states + output = return_output + else: + if is_output_tuple: + head_block_output = [None] * len(output) + head_block_output[0] = output[self._metadata.return_hidden_states_index] + head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] + else: + head_block_output = output + shared_state.head_block_output = head_block_output + shared_state.head_block_residual = hidden_states_residual + + return output + + def reset_state(self, module): + self.state_manager.reset() + return module + + @torch.compiler.disable + def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool: + shared_state = self.state_manager.get_state() + if shared_state.head_block_residual is None: + return True + prev_hidden_states_residual = shared_state.head_block_residual + absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean() + prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean() + diff = (absmean / prev_hidden_states_absmean).item() + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, state_manager: StateManager, is_tail: bool = False): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + original_encoder_hidden_states = None + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + + shared_state = self.state_manager.get_state() + + if shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hidden_states_residual = encoder_hidden_states_residual = None + if isinstance(output, tuple): + hidden_states_residual = ( + output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0] + ) + encoder_hidden_states_residual = ( + output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1] + ) + else: + hidden_states_residual = output - shared_state.head_block_output + shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) + return output + + if original_encoder_hidden_states is None: + return_output = original_hidden_states + else: + return_output = [None, None] + return_output[self._metadata.return_hidden_states_index] = original_hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return_output = tuple(return_output) + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + state_manager = StateManager(FBCSharedBlockState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'") + _apply_fbc_head_block_hook(head_block, state_manager, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Applying FBCBlockHook to '{name}'") + _apply_fbc_block_hook(block, state_manager) + + logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'") + _apply_fbc_block_hook(tail_block, state_manager, is_tail=True) + + +def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state_manager, threshold) + registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state_manager, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 96231aadc3f7..6e097e5882a0 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -18,11 +18,44 @@ import torch from ..utils.logging import get_logger +from ..utils.torch_utils import unwrap_module logger = get_logger(__name__) # pylint: disable=invalid-name +class BaseState: + def reset(self, *args, **kwargs) -> None: + raise NotImplementedError( + "BaseState::reset is not implemented. Please implement this method in the derived class." + ) + + +class StateManager: + def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None): + self._state_cls = state_cls + self._init_args = init_args if init_args is not None else () + self._init_kwargs = init_kwargs if init_kwargs is not None else {} + self._state_cache = {} + self._current_context = None + + def get_state(self): + if self._current_context is None: + raise ValueError("No context is set. Please set a context before retrieving the state.") + if self._current_context not in self._state_cache.keys(): + self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs) + return self._state_cache[self._current_context] + + def set_context(self, name: str) -> None: + self._current_context = name + + def reset(self, *args, **kwargs) -> None: + for name, state in list(self._state_cache.items()): + state.reset(*args, **kwargs) + self._state_cache.pop(name) + self._current_context = None + + class ModelHook: r""" A hook that contains callbacks to be executed just before and after the forward method of a model. @@ -99,6 +132,14 @@ def reset_state(self, module: torch.nn.Module): raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module + def _set_context(self, module: torch.nn.Module, name: str) -> None: + # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them. + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, StateManager): + attr.set_context(name) + return module + class HookFunctionReference: def __init__(self) -> None: @@ -211,9 +252,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook.reset_state(self._module_ref) if recurse: - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) @@ -223,6 +265,19 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry module._diffusers_hook = cls(module) return module._diffusers_hook + def _set_context(self, name: Optional[str] = None) -> None: + for hook_name in reversed(self._hook_order): + hook = self.hooks[hook_name] + if hook._is_stateful: + hook._set_context(self._module_ref, name) + + for module_name, module in unwrap_module(self._module_ref).named_modules(): + if module_name == "": + continue + module = unwrap_module(module) + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(name) + def __repr__(self) -> str: registry_repr = "" for i, hook_name in enumerate(self._hook_order): diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 487a1876d605..14e6c2f8881e 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -150,7 +150,14 @@ def initialize_hook(self, module): def new_forward(self, module: torch.nn.Module, *args, **kwargs): if math.isclose(self.dropout, 1.0): - output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + if self._metadata.return_encoder_hidden_states_index is None: + output = original_hidden_states + else: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + output = (original_hidden_states, original_encoder_hidden_states) else: output = self.fn_ref.original_forward(*args, **kwargs) output = torch.nn.functional.dropout(output, p=self.dropout) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 3fd1ca6e9d3d..605c0d588c8c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + from ..utils.logging import get_logger @@ -25,6 +27,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) + - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) """ _cache_config = None @@ -62,8 +65,10 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) @@ -72,31 +77,36 @@ def enable_cache(self, config) -> None: f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." ) - if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FasterCacheConfig): + if isinstance(config, FasterCacheConfig): apply_faster_cache(self, config) + elif isinstance(config, FirstBlockCacheConfig): + apply_first_block_cache(self, config) + elif isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") return - if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FasterCacheConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) + registry = HookRegistry.check_if_exists_or_initialize(self) + if isinstance(self._cache_config, FasterCacheConfig): registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, FirstBlockCacheConfig): + registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -106,3 +116,15 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None: from ..hooks import HookRegistry HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) + + @contextmanager + def cache_context(self, name: str): + r"""Context manager that provides additional methods for cache management.""" + from ..hooks import HookRegistry + + registry = HookRegistry.check_if_exists_or_initialize(self) + registry._set_context(name) + + yield + + registry._set_context(None) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index d8e99ee45eb6..063ff5bd8e2d 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ def forward( ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = () diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index e4144d0c8e57..dc45befb98fa 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -21,6 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin @@ -453,6 +454,7 @@ def __call__( return hidden_states, encoder_hidden_states +@maybe_allow_in_graph class CogView4TransformerBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3af1de2ad0be..3a7202d0f43f 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -100,7 +104,8 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states @maybe_allow_in_graph @@ -507,20 +512,21 @@ def forward( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -530,12 +536,7 @@ def forward( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 5fb71b69f7ac..bdb9201e62cf 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin @@ -249,6 +250,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +@maybe_allow_in_graph class WanTransformerBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index f08a3c35c2fa..3c5994172c79 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -718,14 +718,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index fe3e8ae388c7..cf6ccebc476d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -784,14 +784,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index a982f4b27557..d1f02ca9c95e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -831,15 +831,16 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - ofs=ofs_emb, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 7c50bdcb7def..230c8ca296ba 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -799,14 +799,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 880253459e1e..d8374b694f0e 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -619,22 +619,10 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred_cond = self.transformer( hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, timestep=timestep, original_size=original_size, target_size=target_size, @@ -643,6 +631,19 @@ def __call__( return_dict=False, )[0] + # perform guidance + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) else: noise_pred = noise_pred_cond diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 4c83ae7405f4..073d94750a02 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -912,32 +912,35 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index b617e4f8b26a..2cbb4af2b4cc 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -693,28 +693,30 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, - guidance=guidance, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, - pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 3b58b4a45a45..77ba75170037 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -757,18 +757,19 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - rope_interpolation_scale=rope_interpolation_scale, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index fa9ee4fc7b87..217478f418ed 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1177,15 +1177,16 @@ def __call__( if is_conditioning_image_or_video: timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - video_coords=video_coords, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 99412b6962a3..8793d81377cc 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -830,18 +830,19 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - rope_interpolation_scale=rope_interpolation_scale, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 7712b415242c..3c0f908296df 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -671,14 +671,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] # Mochi CFG + Sampling runs in FP32 noise_pred = noise_pred.to(torch.float32) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 6df66118b068..d14dac91f14a 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -533,22 +533,24 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ea1999da1853..247769306b53 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -137,6 +137,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FirstBlockCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -201,6 +216,10 @@ def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) +def apply_first_block_cache(*args, **kwargs): + requires_backends(apply_first_block_cache, ["torch"]) + + def apply_layer_skip(*args, **kwargs): requires_backends(apply_layer_skip, ["torch"]) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e5da39c1d865..ebb3d7055319 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -421,6 +421,10 @@ def require_big_accelerator(test_case): Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines: Flux, SD3, Cog, etc. """ + import pytest + + test_case = pytest.mark.big_accelerator(test_case) + if not is_torch_available(): return unittest.skip("test requires PyTorch")(test_case) diff --git a/tests/conftest.py b/tests/conftest.py index 7e9c4e8f3948..3237fb9c7bb0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,10 @@ warnings.simplefilter(action="ignore", category=FutureWarning) +def pytest_configure(config): + config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources") + + def pytest_addoption(parser): from diffusers.utils.testing_utils import pytest_addoption_shared diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 336ac2246fd2..95f1e137e94b 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -20,7 +20,6 @@ import unittest import numpy as np -import pytest import safetensors.torch import torch from parameterized import parameterized @@ -813,7 +812,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -@pytest.mark.big_accelerator class FluxLoRAIntegrationTests(unittest.TestCase): """internal note: The integration slices were obtained on audace. @@ -960,7 +958,6 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -@pytest.mark.big_accelerator class FluxControlLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 19e31f320d0a..4cbd6523e712 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -17,7 +17,6 @@ import unittest import numpy as np -import pytest import torch from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast @@ -198,7 +197,6 @@ def test_simple_inference_with_text_lora_save_load(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -@pytest.mark.big_accelerator class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): """internal note: The integration slices were obtained on DGX. diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 8a8f2a676df1..8928ccbac2dd 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -17,7 +17,6 @@ import unittest import numpy as np -import pytest import torch from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -139,7 +138,6 @@ def test_multiple_wrong_adapter_name_raises_error(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -@pytest.mark.big_accelerator class SD3LoraIntegrationTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index c72558978115..a6cb558513e7 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -33,6 +33,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -45,7 +46,11 @@ class CogVideoXPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 5ee94b09bab0..5b336edc7a88 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -17,7 +17,6 @@ import unittest import numpy as np -import pytest import torch from huggingface_hub import hf_hub_download from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast @@ -211,7 +210,6 @@ def test_flux_image_output_shape(self): @nightly @require_big_accelerator -@pytest.mark.big_accelerator class FluxControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = FluxControlNetPipeline diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 712c26b0a2f9..1f1f800bcf23 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -18,7 +18,6 @@ from typing import Optional import numpy as np -import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -221,7 +220,6 @@ def test_xformers_attention_forwardGenerator_pass(self): @slow @require_big_accelerator -@pytest.mark.big_accelerator class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3ControlNetPipeline diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index cbdf617d71ec..0df0e028ff06 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -2,7 +2,6 @@ import unittest import numpy as np -import pytest import torch from huggingface_hub import hf_hub_download from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel @@ -25,6 +24,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, @@ -34,11 +34,12 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) @@ -224,7 +225,6 @@ def test_flux_true_cfg(self): @nightly @require_big_accelerator -@pytest.mark.big_accelerator class FluxPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-schnell" @@ -312,7 +312,6 @@ def test_flux_inference(self): @slow @require_big_accelerator -@pytest.mark.big_accelerator class FluxIPAdapterPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-dev" diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py index b8f36dfd3cd3..b73050a64df9 100644 --- a/tests/pipelines/flux/test_pipeline_flux_redux.py +++ b/tests/pipelines/flux/test_pipeline_flux_redux.py @@ -2,7 +2,6 @@ import unittest import numpy as np -import pytest import torch from diffusers import FluxPipeline, FluxPriorReduxPipeline @@ -19,7 +18,6 @@ @slow @require_big_accelerator -@pytest.mark.big_accelerator class FluxReduxSlowTests(unittest.TestCase): pipeline_class = FluxPriorReduxPipeline repo_id = "black-forest-labs/FLUX.1-Redux-dev" diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ecc5eba96448..10101af75cee 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np, @@ -43,7 +44,11 @@ class HunyuanVideoPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 1d1eb0823406..bf0c7fde591f 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -23,13 +23,13 @@ from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase): pipeline_class = LTXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LTXVideoTransformer3DModel( in_channels=8, @@ -59,7 +59,7 @@ def get_dummy_components(self): num_attention_heads=4, attention_head_dim=8, cross_attention_dim=32, - num_layers=1, + num_layers=num_layers, caption_channels=32, ) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index 5b00261b06ee..f1684cce72e1 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -17,7 +17,6 @@ import unittest import numpy as np -import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -33,13 +32,15 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class MochiPipelineFastTests( + PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase +): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -268,7 +269,6 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2): @nightly @require_torch_accelerator @require_big_accelerator -@pytest.mark.big_accelerator class MochiPipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 577ac4ebdd4b..2179ec8e226b 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -2,7 +2,6 @@ import unittest import numpy as np -import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -233,7 +232,6 @@ def test_skip_guidance_layers(self): @slow @require_big_accelerator -@pytest.mark.big_accelerator class StableDiffusion3PipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Pipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index f5b5e63a810a..7f913cb63ddf 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -3,7 +3,6 @@ import unittest import numpy as np -import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -168,7 +167,6 @@ def test_multi_vae(self): @slow @require_big_accelerator -@pytest.mark.big_accelerator class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f87778b260c9..13c25ccaa469 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -33,6 +33,7 @@ ) from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -2648,7 +2649,7 @@ def run_forward(pipe): self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep pipe = create_pipe() pipe.transformer.enable_cache(self.faster_cache_config) - output = run_forward(pipe).flatten().flatten() + output = run_forward(pipe).flatten() image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) # Run inference with FasterCache disabled @@ -2755,6 +2756,55 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs): self.assertTrue(state.cache is None, "Cache should be reset to None.") +# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out +# of the box once there is better cache support/implementation +class FirstBlockCacheTesterMixin: + # threshold is intentionally set higher than usual values since we're testing with random unconverged models + # that will not satisfy the expected properties of the denoiser for caching to be effective + first_block_cache_config = FirstBlockCacheConfig(threshold=0.8) + + def test_first_block_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] + + # Run inference without FirstBlockCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.first_block_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), ( + "FirstBlockCache outputs should not differ much." + ) + assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." + ) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From a935bea6b7aaf7c71448806d0fb81e24f6fcae37 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 9 Jul 2025 07:08:09 +0200 Subject: [PATCH 167/170] big doc updategit status! --- docs/source/en/_toctree.yml | 20 +- .../modular_diffusers/auto_pipeline_blocks.md | 316 +++++++ .../modular_diffusers/components_manager.md | 28 +- .../en/modular_diffusers/end_to_end_guide.md | 26 +- .../loop_sequential_pipeline_blocks.md | 193 +++++ .../modular_diffusers_states.md | 59 ++ ...getting_started.md => modular_pipeline.md} | 112 +-- docs/source/en/modular_diffusers/overview.md | 42 + .../en/modular_diffusers/pipeline_block.md | 279 ++++++ .../sequential_pipeline_blocks.md | 189 ++++ .../write_own_pipeline_block.md | 817 ------------------ 11 files changed, 1179 insertions(+), 902 deletions(-) create mode 100644 docs/source/en/modular_diffusers/auto_pipeline_blocks.md create mode 100644 docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md create mode 100644 docs/source/en/modular_diffusers/modular_diffusers_states.md rename docs/source/en/modular_diffusers/{getting_started.md => modular_pipeline.md} (91%) create mode 100644 docs/source/en/modular_diffusers/overview.md create mode 100644 docs/source/en/modular_diffusers/pipeline_block.md create mode 100644 docs/source/en/modular_diffusers/sequential_pipeline_blocks.md delete mode 100644 docs/source/en/modular_diffusers/write_own_pipeline_block.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 24e21a7a4acb..bb2c847f8aff 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -94,14 +94,24 @@ title: API Reference title: Hybrid Inference - sections: - - local: modular_diffusers/getting_started - title: Getting Started + - local: modular_diffusers/overview + title: Overview + - local: modular_diffusers/modular_pipeline + title: Modular Pipeline - local: modular_diffusers/components_manager title: Components Manager - - local: modular_diffusers/write_own_pipeline_block - title: Write your own pipeline block + - local: modular_diffusers/modular_diffusers_states + title: Modular Diffusers States + - local: modular_diffusers/pipeline_block + title: Pipeline Block + - local: modular_diffusers/sequential_pipeline_blocks + title: Sequential Pipeline Blocks + - local: modular_diffusers/loop_sequential_pipeline_blocks + title: Loop Sequential Pipeline Blocks + - local: modular_diffusers/auto_pipeline_blocks + title: Auto Pipeline Blocks - local: modular_diffusers/end_to_end_guide - title: End-to-End Developer Guide + title: End-to-End Example title: Modular Diffusers - sections: - local: using-diffusers/consisid diff --git a/docs/source/en/modular_diffusers/auto_pipeline_blocks.md b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md new file mode 100644 index 000000000000..50c3250512d1 --- /dev/null +++ b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md @@ -0,0 +1,316 @@ + + +# AutoPipelineBlocks + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +`AutoPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that automatically selects which sub-blocks to run based on the inputs provided at runtime, creating conditional workflows that adapt to different scenarios. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use. + +In this tutorial, we will show you how to create an `AutoPipelineBlocks` and learn more about how the conditional selection works. + + + +Other types of multi-blocks include [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `AutoPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). + + + +For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided. + +Let's see an example. We'll use the helper function from the [PipelineBlock guide](./pipeline_block.md) to create our blocks: + +**Helper Function** + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +Now let's create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines. + + +```py +from diffusers.modular_pipelines import AutoPipelineBlocks + +# These are dummy blocks and we only focus on "inputs" for our purpose +inputs = [InputParam(name="prompt")] +# block_fn prints out which workflow is running so we can see the execution order at runtime +block_fn = lambda x, y: print("running the text-to-image workflow") +block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image")] +block_fn = lambda x, y: print("running the image-to-image workflow") +block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")] +block_fn = lambda x, y: print("running the inpaint workflow") +block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!") + +class AutoImageBlocks(AutoPipelineBlocks): + # List of sub-block classes to choose from + block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls] + # Names for each block in the same order + block_names = ["inpaint", "img2img", "text2img"] + # Trigger inputs that determine which block to run + # - "mask" triggers inpaint workflow + # - "image" triggers img2img workflow (but only if mask is not provided) + # - if none of above, runs the text2img workflow (default) + block_trigger_inputs = ["mask", "image", None] + # Description is extremely important for AutoPipelineBlocks + @property + def description(self): + return ( + "Pipeline generates images given different types of conditions!\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + + " - inpaint workflow is run when `mask` is provided.\n" + + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n" + + " - text2img workflow is run when neither `image` nor `mask` is provided.\n" + ) + +# Create the blocks +auto_blocks = AutoImageBlocks() +# convert to pipeline +auto_pipeline = auto_blocks.init_pipeline() +``` + +Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime. + +```py +AutoImageBlocks( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['mask', 'image'] + ==================================================================================================== + + + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + + + Sub-Blocks: + • inpaint [trigger: mask] (TestBlock) + Description: I'm a inpaint workflow! + + • img2img [trigger: image] (TestBlock) + Description: I'm a image-to-image workflow! + + • text2img [default] (TestBlock) + Description: I'm a text-to-image workflow! + +) +``` + +Check out the documentation with `print(auto_pipeline.doc)`: + +```py +>>> print(auto_pipeline.doc) +class AutoImageBlocks + + Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + Inputs: + + prompt (`None`, *optional*): + + image (`None`, *optional*): + + mask (`None`, *optional*): +``` + +There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right? + +This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks. + +Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running: + +```py +>>> _ = auto_pipeline(image="image", mask="mask") +running the inpaint workflow +>>> _ = auto_pipeline(image="image") +running the image-to-image workflow +>>> _ = auto_pipeline(prompt="prompt") +running the text-to-image workflow +>>> _ = auto_pipeline(image="prompt", mask="mask") +running the inpaint workflow +``` + +However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines. + +Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime. + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict +inputs = [InputParam(name="ip_adapter_image")] +block_fn = lambda x, y: print("running the ip-adapter workflow") +block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!") + +class AutoIPAdapter(AutoPipelineBlocks): + block_classes = [block_ipa_cls] + block_names = ["ip-adapter"] + block_trigger_inputs = ["ip_adapter_image"] + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." +``` + +Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`: + +```py +auto_ipa_blocks = AutoIPAdapter() +blocks_dict = InsertableDict() +blocks_dict["ip-adapter"] = auto_ipa_blocks +blocks_dict["image-generation"] = auto_blocks +all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +pipeline = all_blocks.init_pipeline() +``` + +Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases. + +```py +>>> all_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['image', 'mask', 'ip_adapter_image'] + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`). + ==================================================================================================== + + + Description: + + + Sub-Blocks: + [0] ip-adapter (AutoIPAdapter) + Description: Run IP Adapter step if `ip_adapter_image` is provided. + + + [1] image-generation (AutoImageBlocks) + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + +) + +``` + +This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs. + +Let's try some examples: + +`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block. + +```py +>>> all_blocks.get_execution_blocks('mask') +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] image-generation (TestBlock) + Description: I'm a inpaint workflow! + +) +``` + +Let's also actually run the pipeline to confirm: + +```py +>>> _ = pipeline(mask="mask") +skipping auto block: AutoIPAdapter +running the inpaint workflow +``` + +Try a few more: + +```py +print(f"inputs: ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt") +# expect to see ip-adapter + text2img + +print(f"inputs: image:") +blocks_select = all_blocks.get_execution_blocks('image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(image="image", prompt="prompt") +# expect to see img2img + +print(f"inputs: prompt:") +blocks_select = all_blocks.get_execution_blocks('prompt') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(prompt="prompt") +# expect to see text2img (prompt is not a trigger input so fallback to default) + +print(f"inputs: mask + ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image") +# expect to see ip-adapter + inpaint +``` + +In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions. \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md index 316119409ac9..222944e83703 100644 --- a/docs/source/en/modular_diffusers/components_manager.md +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -18,12 +18,12 @@ specific language governing permissions and limitations under the License. -The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows. +The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows. ## Basic Operations -Let's start with the fundamental operations. First, create a Components Manager: +Let's start with the most basic operations. First, create a Components Manager: ```py from diffusers import ComponentsManager @@ -208,7 +208,7 @@ The `get_one()` method returns a single component and supports pattern matching - exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet" - OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae". -You can also filter by collection with `comp.get_one(name="unet", collection="sdxl")` or by load_id. If multiple components match, `get_one()` throws an error. +Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error. Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`. @@ -260,7 +260,7 @@ Now let's load all default components and then create a second pipeline that reu ```py # Load all default components ->>> pipe.load_default_components()` +>>> pipe.load_default_components() # Create a second pipeline using the same Components Manager but with a different collection >>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2") @@ -359,9 +359,9 @@ When enabled, all models start on CPU. The manager moves models to the device ri Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components. -In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will also use Lora and IP-Adapter. +In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. -Let's create a modular text-to-image workflow by separating it into three components: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. +Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. ```py import torch @@ -374,7 +374,9 @@ text_blocks = t2i_blocks.sub_blocks.pop("text_encoder") decoder_blocks = t2i_blocks.sub_blocks.pop("decode") ``` -Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection: +Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection + +Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows. ```py from diffusers import ComponentsManager, ModularPipeline @@ -383,20 +385,21 @@ from diffusers import ComponentsManager, ModularPipeline components = ComponentsManager() components.enable_auto_cpu_offload(device="cuda") -# Create pipelines and load components +# Create a new pipeline to load the components t2i_repo = "YiYiXu/modular-demo-auto" t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") +# convert the 3 blocks into pipelines and attach the same components manager to all 3 text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) ``` -Load all components into the Components Manager under the "t2i" collection: +Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection: ```py # Load all components (including IP-Adapter and ControlNet for later use) -t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16) +t2i_loader_pipe.load_default_components(torch_dtype=torch.float16) ``` Now distribute the loaded components to each pipeline: @@ -432,7 +435,7 @@ image.save("modular_part2_t2i.png") Let's add a LoRA: ```py -# Load LoRA weights - only the UNet gets the adapter +# Load LoRA weights >>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") >>> components Components: @@ -464,7 +467,8 @@ refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"] refiner_blocks.sub_blocks.pop("image_encoder") refiner_blocks.sub_blocks.pop("decode") -# Create refiner pipeline with different repo and collection +# Create refiner pipeline with different repo and collection, +# Attach the same component manager to it refiner_repo = "YiYiXu/modular_refiner" refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner") ``` diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index 42852c6e6420..6a9e4dc31303 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -266,27 +266,27 @@ class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" ) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("denoising_start"), - ] ++ @property ++ def inputs(self) -> List[Tuple[str, Any]]: ++ return [ ++ InputParam("denoising_start"), ++ ] @property def intermediate_inputs(self) -> List[str]: return [ InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("original_latents", type_hint=torch.Tensor), - InputParam("diffdiff_masks", type_hint=torch.Tensor), ++ InputParam("original_latents", type_hint=torch.Tensor), ++ InputParam("diffdiff_masks", type_hint=torch.Tensor), ] def __call__(self, components, block_state, i, t): - # Apply differential diffusion logic - if i == 0 and block_state.denoising_start is None: - block_state.latents = block_state.original_latents[:1] - else: - block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1) - block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) ++ # Apply differential diffusion logic ++ if i == 0 and block_state.denoising_start is None: ++ block_state.latents = block_state.original_latents[:1] ++ else: ++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1) ++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) # ... rest of existing logic ... ``` diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md new file mode 100644 index 000000000000..e97a133d221a --- /dev/null +++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md @@ -0,0 +1,193 @@ + + +# LoopSequentialPipelineBlocks + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops requiring repeated execution of the same blocks. + + + +Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). + + + +You could create a loop using `PipelineBlock` like this: + +```python +class DenoiseLoop(PipelineBlock): + def __call__(self, components, state): + block_state = self.get_block_state(state) + for t in range(block_state.num_inference_steps): + # ... loop logic here + pass + self.set_block_state(state, block_state) + return components, state + +But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations. + +It involves two parts: a **loop wrapper** and **loop blocks** + +* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar. + +* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper. + - they run sequentially for each iteration of the loop + - they receive the current iteration index as an additional parameter + - they share the same block_state throughout the entire loop + +Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations. + +We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps: +1. defining the loop wrapper class +2. creating the loop blocks +3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance + +**Step 1: Define the Loop Wrapper** + +To create a `LoopSequentialPipelineBlocks` class, you need to define: + +* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`) +* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`) +* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`) +* `__call__` method: Defines the loop structure and iteration logic + +Here is an example of a loop wrapper: + +```py +import torch +from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam + +class LoopWrapper(LoopSequentialPipelineBlocks): + model_name = "test" + @property + def description(self): + return "I'm a loop!!" + @property + def loop_inputs(self): + return [InputParam(name="num_steps")] + @torch.no_grad() + def __call__(self, components, state): + block_state = self.get_block_state(state) + # Loop structure - can be customized to your needs + for i in range(block_state.num_steps): + # loop_step executes all registered blocks in sequence + components, block_state = self.loop_step(components, block_state, i=i) + self.set_block_state(state, block_state) + return components, state +``` + +**Step 2: Create Loop Blocks** + +Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: +* It receives the iteration variable (e.g., `i`) passed by the loop wrapper +* It works directly with `block_state` instead of pipeline state +* No need to call `self.get_block_state()` or `self.set_block_state()` + +```py +class LoopBlock(PipelineBlock): + # this is used to identify the model family, we won't worry about it in this example + model_name = "test" + @property + def inputs(self): + return [InputParam(name="x")] + @property + def intermediate_outputs(self): + # outputs produced by this block + return [OutputParam(name="x")] + @property + def description(self): + return "I'm a block used inside the `LoopWrapper` class" + def __call__(self, components, block_state, i: int): + block_state.x += 1 + return components, block_state +``` + +**Step 3: Combine Everything** + +Finally, assemble your loop by adding the block(s) to the wrapper: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock}) +``` + +Now you've created a loop with one step: + +```py +>>> loop +LoopWrapper( + Class: LoopSequentialPipelineBlocks + + Description: I'm a loop!! + + Sub-Blocks: + [0] block1 (LoopBlock) + Description: I'm a block used inside the `LoopWrapper` class + +) +``` + +It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop. + +```py +>>> print(loop.doc) +class LoopWrapper + + I'm a loop!! + + Inputs: + + x (`None`, *optional*): + + num_steps (`None`, *optional*): + + Outputs: + + x (`None`): +``` + +**Running the Loop:** + +```py +# run the loop +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 10 +``` + +**Adding Multiple Blocks:** + +We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock}) +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20 +``` + +**Key Differences from SequentialPipelineBlocks:** + +The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself. + +The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice: + +```py +from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep +StableDiffusionXLDenoiseStep() +``` \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/modular_diffusers_states.md b/docs/source/en/modular_diffusers/modular_diffusers_states.md new file mode 100644 index 000000000000..744089fcf676 --- /dev/null +++ b/docs/source/en/modular_diffusers/modular_diffusers_states.md @@ -0,0 +1,59 @@ + + +# PipelineState and BlockState + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. The concept is fundamental to understand how blocks interact with each other and the pipeline system. + +In the modular diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data. + +A `PipelineState` consists of two distinct states: + +- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it. + +- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them. + +Here's an example of what a `PipelineState` looks like: + +```py +PipelineState( + inputs={ + 'prompt': 'a cat' + 'guidance_scale': 7.0 + 'num_inference_steps': 25 + }, + intermediates={ + 'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1])) + 'negative_prompt_embeds': None + }, +) +``` + +Each pipeline blocks define what parts of that state they can read from and write to through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At run time, they gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. + +For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain: + +```py +BlockState( + image: +) +``` + +You can access the variables directly as attributes: `block_state.image`. + +We will explore more on how blocks interact with pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md). \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/modular_pipeline.md similarity index 91% rename from docs/source/en/modular_diffusers/getting_started.md rename to docs/source/en/modular_diffusers/modular_pipeline.md index 4b82d9c85fb6..c4d82306d127 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/modular_pipeline.md @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Getting Started with Modular Diffusers: A Comprehensive Overview +# ModularPipeline @@ -18,31 +18,34 @@ specific language governing permissions and limitations under the License. -With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you: +`ModularPipeline` is the main interface for end users to run pipelines in Modular Diffusers. It takes pipeline blocks and converts them into a runnable pipeline that can load models and execute the computation steps. -**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. +In this guide, we will focus on how to build pipelines using the blocks we officially support at diffusers 🧨. We'll cover how to use predefined blocks and convert them into a `ModularPipeline` for execution. -**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows. + + +This guide shows you how to use predefined blocks. If you want to learn how to create your own pipeline blocks, see the [PipelineBlock guide](pipeline_block.md) for creating individual blocks, and the multi-block guides for connecting them together: +- [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) +- [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows) +- [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional workflows) -In this guide, we will focus on how to build end-to-end pipelines using blocks we officially support at diffusers 🧨! We will show you how to write your own pipeline blocks and go into more details on how they work under the hood in this [guide](./write_own_pipeline_block.md). For advanced users who want to build complete workflows from scratch, we provide an end-to-end example in the [Developer Guide](./end_to_end.md) that covers everything from writing custom pipeline blocks to deploying your workflow as a UI node. +For information on how data flows through pipelines, see the [PipelineState and BlockState guide](modular_diffusers_states.md). -Let's get started! The Modular Diffusers Framework consists of three main components: -- ModularPipelineBlocks: Building blocks for your workflow, each block defines inputs/outputs and computation steps. These are just definitions and not runnable. -- PipelineState & BlockState: Store and manage data as it flows through the pipeline. -- ModularPipeline: Loads models and runs the computation steps. You convert blocks to pipelines to make them executable. + -## ModularPipelineBlocks + +## Create ModularPipelineBlocks Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: -- [`PipelineBlock`]: The most granular block - you define the computation logic. +- [`PipelineBlock`]: The most granular block - you define the input/output/components requirements and computation logic. - [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. - [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models. - [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. -All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be defined standalone or combined into larger blocks - They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. However, blocks aren't runnable on thier own and they need to be converted into a a ModularPipeline to actually run. +All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. However, blocks aren't runnable on thier own and they need to be converted into a a ModularPipeline to actually run. -**Blocks vs Pipelines**: Blocks are just definitions - they define what components, inputs/outputs, and computation logics are needed, but they don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. See the [ModularPipeline from ModularPipelineBlocks](#modularpipeline-from-modularpipelineblocks) section for how to create and run pipelines. +**Blocks vs Pipelines**: Blocks are just definitions - they define what components, inputs/outputs, and computation logics are needed, but they don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We will first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`. It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers @@ -74,9 +77,7 @@ StableDiffusionXLTextEncoderStep( ) ``` -More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. - -`TEXT2IMAGE_BLOCKS` is a predefined dictionary containing all the blocks needed for a complete text-to-image pipeline (text encoding, denoising, decoding, etc.). We will see more details soon. +More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a preset containing all the blocks needed for text-to-image generation. ```py from diffusers.modular_pipelines import SequentialPipelineBlocks @@ -179,9 +180,9 @@ Note that both the block classes preset and the `sub_blocks` attribute are `Inse **Add a block:** ```py -# BLOCKS is a block class preset, you need to add class to it +# BLOCKS is dict of block classes, you need to add class to it BLOCKS.insert("block_name", BlockClass, index) -# Add a block instance to the `sub_blocks` attribute +# sub_blocks attribute contains instance, add a block instance to the attribute t2i_blocks.sub_blocks.insert("block_name", block_instance, index) ``` @@ -197,7 +198,7 @@ text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder") ```py # Replace block class in preset BLOCKS["prepare_latents"] = CustomPrepareLatents -# Replace in sub_blocks attribute +# Replace in sub_blocks attribute using an block instance t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents() ``` @@ -299,25 +300,10 @@ ALL_BLOCKS = { -We will not go over how to write your own ModularPipelineBlocks but you can learn more about it [here](./write_own_pipeline_block.md). - This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. -## PipelineState & BlockState - -`PipelineState` and `BlockState` manage dataflow between pipeline blocks. `PipelineState` acts as the global state container that `ModularPipelineBlocks` operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` as needed. - - -You typically don't need to manually create or manage these state objects. The `ModularPipeline` automatically creates and manages them for you. However, understanding their roles is important for developing custom pipeline blocks. - - - -## ModularPipeline - -`ModularPipeline` is the main interface to create and execute pipelines in the Modular Diffusers system. - -### Modular Repo +## Modular Repo `ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). @@ -338,13 +324,13 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra ```py "text_encoder": [ - null, # library (same as model_index.json) - null, # class (same as model_index.json) + null, # library of actual loaded component (same as in model_index.json) + null, # class of actual loaded componenet (same as in model_index.json) { # loading specs map (unique to modular_model_index.json) "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo "revision": null, "subfolder": "text_encoder", - "type_hint": [ # (library, class) for the expected component class + "type_hint": [ # (library, class) for the expected component "transformers", "CLIPTextModel" ], @@ -356,7 +342,7 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. -### Creating a `ModularPipeline` from `ModularPipelineBlocks` +## Creating a `ModularPipeline` from `ModularPipelineBlocks` Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. @@ -392,7 +378,7 @@ You can read more about [Components Manager](./components_manager.md) -### Creating a `ModularPipeline` with `from_pretrained` +## Creating a `ModularPipeline` with `from_pretrained` You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: @@ -424,7 +410,7 @@ The `auto_map` tells the pipeline where to find the custom blocks definition - i When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library. -### Loading components into a `ModularPipeline` +## Loading components into a `ModularPipeline` Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=..,)`: @@ -551,7 +537,15 @@ StableDiffusionXLModularPipeline { } ``` -You can see all the components that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: +You can see all the **pretrained components** that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: + + + +**Pretrained vs Config-based Components**: Only pretrained components (like models loaded from Hugging Face Hub) appear in the `modular_model_index.json` file at all. Components created with default configurations at initialization (like schedulers, guiders) are not included in the index since they don't need to be loaded from external sources. + +Whether a component is pretrained or config-based is defined in each pipeline block's `expected_components` field using `ComponentSpec` with the `default_creation_method` parameter. See the [PipelineBlock](./pipeline_block.md) guide for more details on how to define component specifications. + + - **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet. - **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.) @@ -584,9 +578,9 @@ There are also a few properties that can provide a quick summary of component lo ['guider', 'image_processor'] ``` -### Modifying Loading Specs +## Modifying Loading Specs -When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. ```py # Original spec in modular_model_index.json @@ -613,15 +607,23 @@ When you call `pipeline.load_components(names=)` or `pipeline.load_default_compo When you call `pipeline.load_components(...)`/`pipeline.load_default_components()`, it will now load from the new repository by default. -### Updating components in a `ModularPipeline` +## Updating components in a `ModularPipeline` Similar to `DiffusionPipeline`, you can load components separately to replace the default ones in the pipeline. In Modular Diffusers, the approach depends on the component type: -- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them, as they get tagged with a unique ID that encodes their loading parameters -- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly (which will call `create()` under the hood). +- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them to update the existing one. +- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly. + + + +💡 **Component Type Changes**: The component type (pretrained vs config-based) can change when you update components. These types are initially defined in pipeline blocks' `expected_components` field using `ComponentSpec` with `default_creation_method`. See the [Customizing Guidance Techniques](#customizing-guidance-techniques) section for examples of how this works in practice. + + `ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`. +Now let's look at how to update pretrained components in practice: + So instead of ```py @@ -700,7 +702,7 @@ ComponentSpec( -### Customizing Guidance Techniques +## Customizing Guidance Techniques Guiders are implementations of different [classifier-free guidance](https://huggingface.co/papers/2207.12598) techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by steering the model predictions towards desired directions and away from undesired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance`. They can easily be integrated into modular pipelines and provide a flexible way to enhance generation quality without modifying the underlying diffusion models. @@ -826,7 +828,7 @@ The component spec has also been updated to reflect the new guider type: ```py >>> t2i_pipeline.get_component_spec("guider") -ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['use_original_formulation', 'perturbed_guidance_stop', 'stop', 'guidance_rescale', 'start', 'perturbed_guidance_layers', 'perturbed_guidance_start']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') ``` However, the "guider" is still not included in the pipeline config and will not be saved into the `modular_model_index.json` since it remains a `from_config` component: @@ -907,7 +909,7 @@ Additionally, you can write your own guider implementations, for example, CFG Ze -### Running a `ModularPipeline` +## Running a `ModularPipeline` The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`: @@ -926,14 +928,14 @@ Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pi You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list. -**Important**: It is important to always check the docstring because arguments can be different from standard pipelines that you're familar with. For example, in Modular Diffusers we standardized controlnet image input as `control_image`, but regular pipelines have inconsistencies over the names, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`. - -**Note**: The `output` list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`. - ```py t2i_pipeline.doc ``` +**Important**: It is important to always check the docstring because arguments can be different from standard pipelines that you're familar with. For example, in Modular Diffusers we standardized controlnet image input as `control_image`, but regular pipelines have inconsistencies over the names, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`. + +**Note**: The `output` list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`. + #### Text-to-Image, Image-to-Image, and Inpainting @@ -1072,7 +1074,7 @@ StableDiffusionXLAutoControlnetStep( -💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks) for more details. +💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](./auto_pipeline_blocks.md) for more details. diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md new file mode 100644 index 000000000000..359fe5823dae --- /dev/null +++ b/docs/source/en/modular_diffusers/overview.md @@ -0,0 +1,42 @@ + + +# Getting Started with Modular Diffusers + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you: + +**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. + +**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows. + + +Here's how our guides are organized to help you navigate the Modular Diffusers documentation: + +### 🚀 Running Pipelines +- **[modular_pipeline.md](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it +- **[components_manager.md](./components_manager.md)** - How to manage and reuse components across multiple pipelines + +### 📚 Creating PipelineBlocks +- **[modular_diffusers_states.md](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState +- **[pipeline_block.md](./pipeline_block.md)** - How to write custom PipelineBlocks +- **[sequential_pipeline_blocks.md](sequential_pipeline_blocks.md)** - Connecting blocks in sequence +- **[loop_sequential_pipeline_blocks.md](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows +- **[auto_pipeline_blocks.md](./auto_pipeline_blocks.md)** - Conditional block selection + +### 🎯 Practical Examples +- **[end_to_end_guide.md](./end_to_end_guide.md)** - Complete end-to-end examples and practical workflows diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md new file mode 100644 index 000000000000..20f46e928c28 --- /dev/null +++ b/docs/source/en/modular_diffusers/pipeline_block.md @@ -0,0 +1,279 @@ + + +# PipelineBlock + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. + + +**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md). + + +In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state. + +## PipelineState + +Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details. + +## Creating a `PipelineBlock` + +To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. + +The three main properties you need to define are: +- `inputs`: Immutable values from the user that cannot be modified +- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified +- `intermediate_outputs`: New values your block creates for subsequent blocks + +Let's explore each one and understand how they work with the pipeline state. + +**Inputs: Immutable User Values** + +Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`: + +```py +user_inputs = [ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") +] +``` + +When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged." + +This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency. + +**Intermediate Inputs: Mutable Values from Previous Blocks** + +Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case: + +```py +user_intermediate_inputs = [ + InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"), +] +``` + +When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different." + +**Intermediate Outputs: New Values for Subsequent Blocks** + +Intermediate outputs are new variables your block creates and adds to the mutable pipeline state so they can be used by subsequent blocks: + +```py +user_intermediate_outputs = [ + OutputParam(name="image_latents", description="latents representing the image") +] +``` + +Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match. + +**The `__call__` Method Structure** + +Your `PipelineBlock`'s `__call__` method should follow this structure: + +```py +def __call__(self, components, state): + # Get a local view of the state variables this block needs + block_state = self.get_block_state(state) + + # Your computation logic here + # block_state contains all your inputs and intermediate_inputs + # You can access them like: block_state.image, block_state.processed_image + + # Update the pipeline state with your updated block_states + self.set_block_state(state, block_state) + return components, state +``` + +The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation. + +**Components and Configs** + +You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`: + +```py +from diffusers import ComponentSpec, ConfigSpec + +# Define components your block needs +expected_components = [ + ComponentSpec(name="unet", type_hint=UNet2DConditionModel), + ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler) +] + +# Define pipeline-level configs +expected_config = [ + ConfigSpec("force_zeros_for_empty_prompt", True) +] +``` + +**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with default configurations. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md). + +**Configs**: Pipeline-level settings that control behavior across all blocks. + +When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block in the `components` argument of the `__call__` method. + +That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks + +Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. + +**Helper Function** + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +## Example: Creating a Simple Pipeline Block + +Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them: + +```py +inputs = [ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") +] + +intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] + +intermediate_outputs = [ + OutputParam(name="image_latents", description="latents representing the image") +] + +def image_encoder_block_fn(block_state, pipeline_state): + print(f"pipeline_state (before update): {pipeline_state}") + print(f"block_state (before update): {block_state}") + + # Simulate processing the image + block_state.image = torch.randn(1, 3, 512, 512) + block_state.batch_size = block_state.batch_size * 2 + block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size + block_state.image_latents = torch.randn(1, 4, 64, 64) + + print(f"block_state (after update): {block_state}") + return block_state + +# Create a block with our definitions +image_encoder_block_cls = make_block( + inputs=inputs, + intermediate_inputs=intermediate_inputs, + intermediate_outputs=intermediate_outputs, + block_fn=image_encoder_block_fn, + description="Encode raw image into its latent presentation" +) +image_encoder_block = image_encoder_block_cls() +pipe = image_encoder_block.init_pipeline() +``` + +Let's check the pipeline's docstring to see what inputs it expects: +```py +>>> print(pipe.doc) +class TestBlock + + Encode raw image into its latent presentation + + Inputs: + + image (`PIL.Image`, *optional*): + raw input image to process + + batch_size (`int`, *optional*): + + Outputs: + + image_latents (`None`): + latents representing the image +``` + +Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state. + +Now let's run the pipeline: + +```py +from diffusers.utils import load_image + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png") +state = pipe(image=image, batch_size=2) +print(f"pipeline_state (after update): {state}") +``` +```out +pipeline_state (before update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 2 + }, +) +block_state (before update): BlockState( + image: + batch_size: 2 +) + +block_state (after update): BlockState( + image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) + batch_size: 4 + processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) +) +pipeline_state (after update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 4 + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) + }, +) +``` + +**Key Observations:** + +1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. + +2. **After the update**: + - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only. + - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) + - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output + - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output + +Understanding how to create `PipelineBlock`s is fundamental to building modular workflows in Modular Diffusers. Remember that `PipelineBlock`s are definitions/specifications - they define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md new file mode 100644 index 000000000000..a683f0d0659a --- /dev/null +++ b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md @@ -0,0 +1,189 @@ + + +# SequentialPipelineBlocks + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks. + + + +Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). + + + +In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together. + +The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. + +Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks. + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +Let's create a block that produces `batch_size`, which we'll call "input_block": + +```py +def input_block_fn(block_state, pipeline_state): + + batch_size = len(block_state.prompt) + block_state.batch_size = batch_size * block_state.num_images_per_prompt + + return block_state + +input_block_cls = make_block( + inputs=[ + InputParam(name="prompt", type_hint=list, description="list of text prompts"), + InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt") + ], + intermediate_outputs=[ + OutputParam(name="batch_size", description="calculated batch size") + ], + block_fn=input_block_fn, + description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument." +) +input_block = input_block_cls() +``` + +Now let's create a second block that uses the `batch_size` from the first block: + +```py +def image_encoder_block_fn(block_state, pipeline_state): + # Simulate processing the image + block_state.image = torch.randn(1, 3, 512, 512) + block_state.batch_size = block_state.batch_size * 2 + block_state.image_latents = torch.randn(1, 4, 64, 64) + return block_state + +image_encoder_block_cls = make_block( + inputs=[ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") + ], + intermediate_inputs=[ + InputParam(name="batch_size", type_hint=int) + ], + intermediate_outputs=[ + OutputParam(name="image_latents", description="latents representing the image") + ], + block_fn=image_encoder_block_fn, + description="Encode raw image into its latent presentation" +) +image_encoder_block = image_encoder_block_cls() +``` + +Now let's connect these blocks to create a `SequentialPipelineBlocks`: + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict + +# Define a dict mapping block names to block instances +blocks_dict = InsertableDict() +blocks_dict["input"] = input_block +blocks_dict["image_encoder"] = image_encoder_block + +# Create the SequentialPipelineBlocks +blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +``` + +Now you have a `SequentialPipelineBlocks` with 2 blocks: + +```py +>>> blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] input (TestBlock) + Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument. + + [1] image_encoder (TestBlock) + Description: Encode raw image into its latent presentation + +) +``` + +When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. + +```py +>>> print(blocks.doc) +class SequentialPipelineBlocks + + Inputs: + + prompt (`None`, *optional*): + + num_images_per_prompt (`None`, *optional*): + + image (`PIL.Image`, *optional*): + raw input image to process + + Outputs: + + batch_size (`None`): + + image_latents (`None`): + latents representing the image +``` + +At runtime, you have data flow like this: + +![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png) + +**How SequentialPipelineBlocks Works:** + +1. Blocks are executed in the order they're registered in the `blocks_dict` +2. Outputs from one block become available as intermediate inputs to all subsequent blocks +3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks +4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces + +What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md deleted file mode 100644 index ae2d819e7f61..000000000000 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ /dev/null @@ -1,817 +0,0 @@ - - -# Writing Your Own Pipeline Blocks - - - -🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. - - - -In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. - -In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. - - -## Understanding the Foundation: `PipelineState` - -Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system. - -In the modular diffusers system, `PipelineState` acts as the global state container that `PipelineBlock`s operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. - -While `PipelineState` maintains the complete runtime state of the pipeline, `PipelineBlock`s define what parts of that state they can read from and write to through their `input`s, `intermediates_inputs`, and `intermediates_outputs` properties. - -A `PipelineState` consists of two distinct states: -- The **immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it. -- The **mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them. - -Here's an example of what a `PipelineState` looks like: - -``` -PipelineState( - inputs={ - prompt: 'a cat' - guidance_scale: 7.0 - num_inference_steps: 25 - }, - intermediates={ - prompt_embeds: Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1])) - negative_prompt_embeds: None - }, -``` - -## Creating a `PipelineBlock` - -To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. - -The three main properties you need to define are: -- `inputs`: Immutable values from the user that cannot be modified -- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified -- `intermediate_outputs`: New values your block creates for subsequent blocks - -Let's explore each one and understand how they work with the pipeline state. - -**Inputs: Immutable User Values** - -Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`: - -```py -user_inputs = [ - InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") -] -``` - -When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged." - -This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency. - -**Intermediate Inputs: Mutable Values from Previous Blocks** - -Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case: - -```py -user_intermediate_inputs = [ - InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"), -] -``` - -When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different." - -**Intermediate Outputs: New Values for Subsequent Blocks** - -Intermediate outputs are new variables your block creates and adds to the mutable pipeline state so they can be used by subsequent blocks: - -```py -user_intermediate_outputs = [ - OutputParam(name="image_latents", description="latents representing the image") -] -``` - -Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match. We will see more how they connect when we talk about multi-blocks. - -**The `__call__` Method Structure** - -Your `PipelineBlock`'s `__call__` method should follow this structure: - -```py -def __call__(self, components, state): - # Get a local view of the state variables this block needs - block_state = self.get_block_state(state) - - # Your computation logic here - # block_state contains all your inputs and intermediate_inputs - # You can access them like: block_state.image, block_state.processed_image - - # Update the pipeline state with your updated block_states - self.set_block_state(state, block_state) - return components, state -``` - -The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation. - -**Components and Configs** - -You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`: - -```py -from diffusers import ComponentSpec, ConfigSpec - -# Define components your block needs -expected_components = [ - ComponentSpec(name="unet", type_hint=UNet2DConditionModel), - ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler) -] - -# Define pipeline-level configs -expected_config = [ - ConfigSpec("force_zeros_for_empty_prompt", True) -] -``` - -**Components**: In the `ComponentSpec`, You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Getting Started Guide](https://huggingface.co/docs/diffusers/en/modular_diffusers/getting_started#loading-components-into-a-modularpipeline). - -**Configs**: Simple pipeline-level settings that control behavior across all blocks. - -When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block in the `components` argument of the `__call__` method. - -That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks - -Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. - -**Helper Function** - -```py -from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam -import torch - -def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): - class TestBlock(PipelineBlock): - model_name = "test" - - @property - def inputs(self): - return inputs - - @property - def intermediate_inputs(self): - return intermediate_inputs - - @property - def intermediate_outputs(self): - return intermediate_outputs - - @property - def description(self): - return description if description is not None else "" - - def __call__(self, components, state): - block_state = self.get_block_state(state) - if block_fn is not None: - block_state = block_fn(block_state, state) - self.set_block_state(state, block_state) - return components, state - - return TestBlock -``` - - -Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them: - -```py -inputs = [ - InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") -] - -intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] - -intermediate_outputs = [ - OutputParam(name="image_latents", description="latents representing the image") -] - -def image_encoder_block_fn(block_state, pipeline_state): - print(f"pipeline_state (before update): {pipeline_state}") - print(f"block_state (before update): {block_state}") - - # Simulate processing the image - block_state.image = torch.randn(1, 3, 512, 512) - block_state.batch_size = block_state.batch_size * 2 - block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size - block_state.image_latents = torch.randn(1, 4, 64, 64) - - print(f"block_state (after update): {block_state}") - return block_state - -# Create a block with our definitions -image_encoder_block_cls = make_block( - inputs=inputs, - intermediate_inputs=intermediate_inputs, - intermediate_outputs=intermediate_outputs, - block_fn=image_encoder_block_fn, - description=" Encode raw image into its latent presentation" -) -image_encoder_block = image_encoder_block_cls() -pipe = image_encoder_block.init_pipeline() -``` - -Let's check the pipeline's docstring to see what inputs it expects: -```py ->>> print(pipe.doc) -class TestBlock - - Encode raw image into its latent presentation - - Inputs: - - image (`PIL.Image`, *optional*): - raw input image to process - - batch_size (`int`, *optional*): - - Outputs: - - image_latents (`None`): - latents representing the image -``` - -Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state. - -Now let's run the pipeline: - -```py -from diffusers.utils import load_image - -image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png") -state = pipe(image=image, batch_size=2) -print(f"pipeline_state (after update): {state}") -``` -```out -pipeline_state (before update): PipelineState( - inputs={ - image: - }, - intermediates={ - batch_size: 2 - }, -) -block_state (before update): BlockState( - image: - batch_size: 2 -) - -block_state (after update): BlockState( - image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) - batch_size: 4 - processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] - image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) -) -pipeline_state (after update): PipelineState( - inputs={ - image: - }, - intermediates={ - batch_size: 4 - image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) - }, -) -``` -**Key Observations:** - -1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. - -2. **After the update**: - - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only. - - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) - - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output - - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output - -I hope by now you have a basic idea about how `PipelineBlock` manages state through inputs, intermediate inputs, and intermediate outputs. The real power comes when we connect multiple blocks together - their intermediate outputs become intermediate inputs for subsequent blocks, creating modular workflows. Let's explore how to build these connections using multi-blocks like `SequentialPipelineBlocks`. - -## Create a `SequentialPipelineBlocks` - -I assume that you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [Getting Started Guide](https://huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/getting_started#modularpipelineblocks). - -But how do blocks actually connect and work together? Understanding this is crucial for building effective modular workflows. Let's explore this through an example. - -**How Blocks Connect in SequentialPipelineBlocks:** - -The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed earlier. Let's expand on our example to create a new block that produces `batch_size`, which we'll call "input_block": - -```py -def input_block_fn(block_state, pipeline_state): - - batch_size = len(block_state.prompt) - block_state.batch_size = batch_size * block_state.num_images_per_prompt - - return block_state - -input_block_cls = make_block( - inputs=[ - InputParam(name="prompt", type_hint=list, description="list of text prompts"), - InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt") - ], - intermediate_outputs=[ - OutputParam(name="batch_size", description="calculated batch size") - ], - block_fn=input_block_fn, - description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument." -) -input_block = input_block_cls() -``` - -Now let's connect these blocks to create a pipeline: - -```py -from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict -# define a dict map block names to block class -blocks_dict = InsertableDict() -blocks_dict["input"] = input_block -blocks_dict["image_encoder"] = image_encoder_block -# create the multi-block -blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) -# convert it to a runnable pipeline -pipeline = blocks.init_pipeline() -``` - -Now you have a pipeline with 2 blocks. - -```py ->>> pipeline.blocks -SequentialPipelineBlocks( - Class: ModularPipelineBlocks - - Description: - - - Sub-Blocks: - [0] input (TestBlock) - Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument. - - [1] image_encoder (TestBlock) - Description: Encode raw image into its latent presentation - -) -``` - -When you inspect `pipeline.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. - -```py ->>> print(pipeline.doc) -class SequentialPipelineBlocks - - Inputs: - - prompt (`None`, *optional*): - - num_images_per_prompt (`None`, *optional*): - - image (`PIL.Image`, *optional*): - raw input image to process - - Outputs: - - batch_size (`None`): - - image_latents (`None`): - latents representing the image -``` - -At runtime, you have data flow like this: - -![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png) - -**How SequentialPipelineBlocks Works:** - -1. Blocks are executed in the order they're registered in the `blocks_dict` -2. Outputs from one block become available as intermediate inputs to all subsequent blocks -3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks -4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces - -What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. - -## `LoopSequentialPipelineBlocks` - -To create a loop in Modular Diffusers, you could use a single `PipelineBlock` like this: - -```python -class DenoiseLoop(PipelineBlock): - def __call__(self, components, state): - block_state = self.get_block_state(state) - for t in range(block_state.num_inference_steps): - # ... loop logic here - pass - self.set_block_state(state, block_state) - return components, state -``` - -Or you could create a `LoopSequentialPipelineBlocks`. The key difference is that with `LoopSequentialPipelineBlocks`, the loop itself is modular: you can add or remove blocks within the loop or reuse the same loop structure with different block combinations. - -It involves two parts: a **loop wrapper** and **loop blocks** - -* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar. - -* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper. - - they run sequentially for each iteration of the loop - - they receive the current iteration index as an additional parameter - - they share the same block_state throughout the entire loop - -Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations. - -We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps: -1. defining the loop wrapper class -2. creating the loop blocks -3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance - -**Step 1: Define the Loop Wrapper** - -To create a `LoopSequentialPipelineBlocks` class, you need to define: - -* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`) -* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`) -* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`) -* `__call__` method: Defines the loop structure and iteration logic - -Here is an example of a loop wrapper: - -```py -import torch -from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam - -class LoopWrapper(LoopSequentialPipelineBlocks): - model_name = "test" - @property - def description(self): - return "I'm a loop!!" - @property - def loop_inputs(self): - return [InputParam(name="num_steps")] - @torch.no_grad() - def __call__(self, components, state): - block_state = self.get_block_state(state) - # Loop structure - can be customized to your needs - for i in range(block_state.num_steps): - # loop_step executes all registered blocks in sequence - components, block_state = self.loop_step(components, block_state, i=i) - self.set_block_state(state, block_state) - return components, state -``` - -**Step 2: Create Loop Blocks** - -Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: -* It receives the iteration variable (e.g., `i`) passed by the loop wrapper -* It works directly with `block_state` instead of pipeline state -* No need to call `self.get_block_state()` or `self.set_block_state()` - -```py -class LoopBlock(PipelineBlock): - # this is used to identify the model family, we won't worry about it in this example - model_name = "test" - @property - def inputs(self): - return [InputParam(name="x")] - @property - def intermediate_outputs(self): - # outputs produced by this block - return [OutputParam(name="x")] - @property - def description(self): - return "I'm a block used inside the `LoopWrapper` class" - def __call__(self, components, block_state, i: int): - block_state.x += 1 - return components, block_state -``` - -**Step 3: Combine Everything** - -Finally, assemble your loop by adding the block(s) to the wrapper: - -```py -loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock}) -``` - -Now you've created a loop with one step: - -```py ->>> loop -LoopWrapper( - Class: LoopSequentialPipelineBlocks - - Description: I'm a loop!! - - Sub-Blocks: - [0] block1 (LoopBlock) - Description: I'm a block used inside the `LoopWrapper` class - -) -``` - -It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop. - -```py ->>> print(loop.doc) -class LoopWrapper - - I'm a loop!! - - Inputs: - - x (`None`, *optional*): - - num_steps (`None`, *optional*): - - Outputs: - - x (`None`): -``` - -**Running the Loop:** - -```py -# run the loop -loop_pipeline = loop.init_pipeline() -x = loop_pipeline(num_steps=10, x=0, output="x") -assert x == 10 -``` - -**Adding Multiple Blocks:** - -We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration: - -```py -loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock}) -loop_pipeline = loop.init_pipeline() -x = loop_pipeline(num_steps=10, x=0, output="x") -assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20 -``` - -**Key Differences from SequentialPipelineBlocks:** - -The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself. - -The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice: - -```py -from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep -StableDiffusionXLDenoiseStep() -``` - -## `AutoPipelineBlocks` - -`AutoPipelineBlocks` allows you to pack different pipelines into one and automatically select which one to run at runtime based on the inputs. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use. - -For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided. - -Let's see an example. Here we'll create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines. - - -```py -from diffusers.modular_pipelines import AutoPipelineBlocks - -# These are dummy blocks and we only focus on "inputs" for our purpose -inputs = [InputParam(name="prompt")] -# block_fn prints out which workflow is running so we can see the execution order at runtime -block_fn = lambda x, y: print("running the text-to-image workflow") -block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!") - -inputs = [InputParam(name="prompt"), InputParam(name="image")] -block_fn = lambda x, y: print("running the image-to-image workflow") -block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!") - -inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")] -block_fn = lambda x, y: print("running the inpaint workflow") -block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!") - -class AutoImageBlocks(AutoPipelineBlocks): - # List of sub-block classes to choose from - block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls] - # Names for each block in the same order - block_names = ["inpaint", "img2img", "text2img"] - # Trigger inputs that determine which block to run - # - "mask" triggers inpaint workflow - # - "image" triggers img2img workflow (but only if mask is not provided) - # - if none of above, runs the text2img workflow (default) - block_trigger_inputs = ["mask", "image", None] - # Description is extremely important for AutoPipelineBlocks - @property - def description(self): - return ( - "Pipeline generates images given different types of conditions!\n" - + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" - + " - inpaint workflow is run when `mask` is provided.\n" - + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n" - + " - text2img workflow is run when neither `image` nor `mask` is provided.\n" - ) - -# Create the blocks -auto_blocks = AutoImageBlocks() -# convert to pipeline -auto_pipeline = auto_blocks.init_pipeline() -``` - -Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime. - -```py -AutoImageBlocks( - Class: AutoPipelineBlocks - - ==================================================================================================== - This pipeline contains blocks that are selected at runtime based on inputs. - Trigger Inputs: ['mask', 'image'] - ==================================================================================================== - - - Description: Pipeline generates images given different types of conditions! - This is an auto pipeline block that works for text2img, img2img and inpainting tasks. - - inpaint workflow is run when `mask` is provided. - - img2img workflow is run when `image` is provided (but only when `mask` is not provided). - - text2img workflow is run when neither `image` nor `mask` is provided. - - - - Sub-Blocks: - • inpaint [trigger: mask] (TestBlock) - Description: I'm a inpaint workflow! - - • img2img [trigger: image] (TestBlock) - Description: I'm a image-to-image workflow! - - • text2img [default] (TestBlock) - Description: I'm a text-to-image workflow! - -) -``` - -Check out the documentation with `print(auto_pipeline.doc)`: - -```py ->>> print(auto_pipeline.doc) -class AutoImageBlocks - - Pipeline generates images given different types of conditions! - This is an auto pipeline block that works for text2img, img2img and inpainting tasks. - - inpaint workflow is run when `mask` is provided. - - img2img workflow is run when `image` is provided (but only when `mask` is not provided). - - text2img workflow is run when neither `image` nor `mask` is provided. - - Inputs: - - prompt (`None`, *optional*): - - image (`None`, *optional*): - - mask (`None`, *optional*): -``` - -There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right? - -This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks. - -Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running: - -```py ->>> _ = auto_pipeline(image="image", mask="mask") -running the inpaint workflow ->>> _ = auto_pipeline(image="image") -running the image-to-image workflow ->>> _ = auto_pipeline(prompt="prompt") -running the text-to-image workflow ->>> _ = auto_pipeline(image="prompt", mask="mask") -running the inpaint workflow -``` - -However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines. - -Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime. - -```py -from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict -inputs = [InputParam(name="ip_adapter_image")] -block_fn = lambda x, y: print("running the ip-adapter workflow") -block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!") - -class AutoIPAdapter(AutoPipelineBlocks): - block_classes = [block_ipa_cls] - block_names = ["ip-adapter"] - block_trigger_inputs = ["ip_adapter_image"] - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." -``` - -Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`: - -```py -auto_ipa_blocks = AutoIPAdapter() -blocks_dict = InsertableDict() -blocks_dict["ip-adapter"] = auto_ipa_blocks -blocks_dict["image-generation"] = auto_blocks -all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) -pipeline = all_blocks.init_pipeline() -``` - -Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases. - -```py ->>> all_blocks -SequentialPipelineBlocks( - Class: ModularPipelineBlocks - - ==================================================================================================== - This pipeline contains blocks that are selected at runtime based on inputs. - Trigger Inputs: ['image', 'mask', 'ip_adapter_image'] - Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`). - ==================================================================================================== - - - Description: - - - Sub-Blocks: - [0] ip-adapter (AutoIPAdapter) - Description: Run IP Adapter step if `ip_adapter_image` is provided. - - - [1] image-generation (AutoImageBlocks) - Description: Pipeline generates images given different types of conditions! - This is an auto pipeline block that works for text2img, img2img and inpainting tasks. - - inpaint workflow is run when `mask` is provided. - - img2img workflow is run when `image` is provided (but only when `mask` is not provided). - - text2img workflow is run when neither `image` nor `mask` is provided. - - -) - -``` - -This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs. - -Let's try some examples: - -`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block. - -```py ->>> all_blocks.get_execution_blocks('mask') -SequentialPipelineBlocks( - Class: ModularPipelineBlocks - - Description: - - - Sub-Blocks: - [0] image-generation (TestBlock) - Description: I'm a inpaint workflow! - -) -``` - -Let's also actually run the pipeline to confirm: - -```py ->>> _ = pipeline(mask="mask") -skipping auto block: AutoIPAdapter -running the inpaint workflow -``` - -Try a few more: - -```py -print(f"inputs: ip_adapter_image:") -blocks_select = all_blocks.get_execution_blocks('ip_adapter_image') -print(f"expected_execution_blocks: {blocks_select}") -print(f"actual execution blocks:") -_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt") -# expect to see ip-adapter + text2img - -print(f"inputs: image:") -blocks_select = all_blocks.get_execution_blocks('image') -print(f"expected_execution_blocks: {blocks_select}") -print(f"actual execution blocks:") -_ = pipeline(image="image", prompt="prompt") -# expect to see img2img - -print(f"inputs: prompt:") -blocks_select = all_blocks.get_execution_blocks('prompt') -print(f"expected_execution_blocks: {blocks_select}") -print(f"actual execution blocks:") -_ = pipeline(prompt="prompt") -# expect to see text2img (prompt is not a trigger input so fallback to default) - -print(f"inputs: mask + ip_adapter_image:") -blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image') -print(f"expected_execution_blocks: {blocks_select}") -print(f"actual execution blocks:") -_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image") -# expect to see ip-adapter + inpaint -``` - -In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions. \ No newline at end of file From 2104bef01d7a948a3982c8ec6038c61ee04da5dc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 9 Jul 2025 23:04:54 +0200 Subject: [PATCH 168/170] update more modular pipeline doc --- .../en/modular_diffusers/modular_pipeline.md | 191 ++++++++++-------- docs/source/en/modular_diffusers/overview.md | 16 +- 2 files changed, 116 insertions(+), 91 deletions(-) diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md index c4d82306d127..55182b921fdb 100644 --- a/docs/source/en/modular_diffusers/modular_pipeline.md +++ b/docs/source/en/modular_diffusers/modular_pipeline.md @@ -36,17 +36,15 @@ For information on how data flows through pipelines, see the [PipelineState and ## Create ModularPipelineBlocks -Pipeline blocks are the fundamental building blocks of the Modular Diffusers system. All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: +In Modular Diffusers system, you build pipelines using Pipeline blocks. Pipeline Blocks are fundamental building blocks - they define what components, inputs/outputs, and computation logics are needed. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. But they are just definitions and don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We'll first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`. + +All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: - [`PipelineBlock`]: The most granular block - you define the input/output/components requirements and computation logic. - [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. - [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models. - [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. -All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. However, blocks aren't runnable on thier own and they need to be converted into a a ModularPipeline to actually run. - -**Blocks vs Pipelines**: Blocks are just definitions - they define what components, inputs/outputs, and computation logics are needed, but they don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We will first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`. - It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers ```py @@ -77,7 +75,7 @@ StableDiffusionXLTextEncoderStep( ) ``` -More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a preset containing all the blocks needed for text-to-image generation. +More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a dict containing all the blocks needed for text-to-image generation. ```py from diffusers.modular_pipelines import SequentialPipelineBlocks @@ -85,7 +83,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) ``` -This creates a `SequentialPipelineBlocks`, which is a multi-block composed of other blocks. Unlike single blocks (like the `text_encoder_block` we saw earlier), this multi-block has a `sub_blocks` attribute that contains the sub-blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_added_con, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from these blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another. +This creates a `SequentialPipelineBlocks`. Unlike the `text_encoder_block` we saw earlier, this is a multi-block and its `sub_blocks` attribute contains a list of other blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_added_con, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from these blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another. ```py >>> t2i_blocks @@ -146,7 +144,7 @@ SequentialPipelineBlocks( ) ``` -The block classes preset (`TEXT2IMAGE_BLOCKS`) we used is just a dictionary that maps names to ModularPipelineBlocks classes +This is the block classes preset (`TEXT2IMAGE_BLOCKS`) we used: It is just a dictionary that maps names to ModularPipelineBlocks classes ```py >>> TEXT2IMAGE_BLOCKS @@ -210,7 +208,9 @@ Let's make a new block classes preset by insert IP-Adapter at index 0 (before th ```py from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy() +# CUSTOM_BLOCKS is now a preset including ip_adapter CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) +# create a blocks isntance from the preset custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS) ``` @@ -300,12 +300,16 @@ ALL_BLOCKS = { -This covers the essentials of pipeline blocks! You may have noticed that we haven't discussed how to load or run pipeline blocks - that's because **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. +This covers the essentials of pipeline blocks! Like we have already mentioned, **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. ## Modular Repo -`ModularPipeline` only works with modular repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). +To convert blocks into a runnable pipeline, you may need a repository if your blocks contain **pretrained components** (models with checkpoints that need to be loaded from the Hub). Pipeline blocks define what components they need (like a UNet, text encoder, etc.), as well as how to create them: components can be either created using **from_pretrained** method (with checkpoints) or **from_config** (initialized from scratch with default configuration, usually stateless like a guider or scheduler). + +If your pipeline contains **pretrained components**, you typically need to use a repository to provide the loading specifications and metadata. + +`ModularPipeline` works specifically with modular repositories, which offer more flexibility in component loading compared to traditional repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here. @@ -346,56 +350,57 @@ Unlike standard repositories where components must be in subfolders within the s Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. -Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`: +Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`. We'll use a `ComponentsManager` to handle device placement, memory management, and component reuse automatically: ```py # We already have this from earlier t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) # Now convert it to a ModularPipeline +from diffusers import ComponentsManager modular_repo_id = "YiYiXu/modular-loader-t2i-0704" -t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) +components = ComponentsManager() +t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components) ``` -The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet. - -💡 We recommend using `ModularPipeline` with Component Manager by passing a `components_manager`: - -```py ->>> components = ComponentsManager() ->>> pipeline = blocks.init_pipeline(modular_repo_id, components_manager=components) -``` - -This helps you to: -1. Detect and manage duplicated models (warns when trying to register an existing model) -2. Easily reuse components across different pipelines -3. Apply offloading strategies across multiple pipelines - -You can read more about [Components Manager](./components_manager.md) +💡 **ComponentsManager** is the model registry and management system in diffusers, it track all the models in one place and let you add, remove and reuse them across different workflows in most efficient way. Without it, you'd need to manually manage GPU memory, device placement, and component sharing between workflows. See the [Components Manager guide](components_manager.md) for detailed information. +The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet. + ## Creating a `ModularPipeline` with `from_pretrained` You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: ```py -from diffusers import ModularPipeline -pipeline = ModularPipeline.from_pretrained( "YiYiXu/modular-loader-t2i-0704") +from diffusers import ModularPipeline, ComponentsManager +components = ComponentsManager() +pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components) ``` Loading custom code is also supported: ```py -from diffusers import ModularPipeline +from diffusers import ModularPipeline, ComponentsManager +components = ComponentsManager() modular_repo_id = "YiYiXu/modular-diffdiff-0704" -diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True) +diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components) +``` + +This modular repository contains custom code. The folder contains these files: + +``` +modular-diffdiff-0704/ +├── block.py # Custom pipeline blocks implementation +├── config.json # Pipeline configuration and auto_map +└── modular_model_index.json # Component loading specifications ``` -This modular repository contains custom code. The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation: +The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation: ```json { @@ -539,14 +544,6 @@ StableDiffusionXLModularPipeline { You can see all the **pretrained components** that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: - - -**Pretrained vs Config-based Components**: Only pretrained components (like models loaded from Hugging Face Hub) appear in the `modular_model_index.json` file at all. Components created with default configurations at initialization (like schedulers, guiders) are not included in the index since they don't need to be loaded from external sources. - -Whether a component is pretrained or config-based is defined in each pipeline block's `expected_components` field using `ComponentSpec` with the `default_creation_method` parameter. See the [PipelineBlock](./pipeline_block.md) guide for more details on how to define component specifications. - - - - **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet. - **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.) @@ -578,9 +575,11 @@ There are also a few properties that can provide a quick summary of component lo ['guider', 'image_processor'] ``` +From config components (like `guider` and `image_processor`) are not included in the pipeline output above because they don't need loading specs - they're already initialized during pipeline creation. You can see this because they're not listed in `null_component_names`. + ## Modifying Loading Specs -When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. ```py # Original spec in modular_model_index.json @@ -604,7 +603,12 @@ When you call `pipeline.load_components(names=)` or `pipeline.load_default_compo ] ``` -When you call `pipeline.load_components(...)`/`pipeline.load_default_components()`, it will now load from the new repository by default. +Now if you create a pipeline using the same blocks and updated repository, it will by default load from the new repository. + +```py +pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components) +pipeline.load_components(names="unet") +``` ## Updating components in a `ModularPipeline` @@ -631,7 +635,7 @@ from diffusers import UNet2DConditionModel import torch unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16) ``` -You should do +You should load your model like this ```py from diffusers import ComponentSpec, UNet2DConditionModel @@ -639,13 +643,15 @@ unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stab unet2 = unet_spec.load(torch_dtype=torch.float16) ``` -The key difference is that the second unet (the one we load with `ComponentSpec`) retains its loading specs, so you can extract and recreate it: +The key difference is that the second unet retains its loading specs, so you can extract the spec and recreate the unet: ```py -# to extract spec, you can do spec.load() to recreate it +# component -> spec >>> spec = ComponentSpec.from_component("unet", unet2) >>> spec ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained') +# spec -> component +>>> unet2_recreatd = spec.load(torch_dtype=torch.float16) ``` To replace the unet in the pipeline @@ -654,7 +660,7 @@ To replace the unet in the pipeline t2i_pipeline.update_components(unet=unet2) ``` -Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0". This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet. +Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0" in pipeline config. This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet. ``` >>> t2i_pipeline @@ -739,6 +745,9 @@ ClassifierFreeGuidance { To change parameters of the same guider type (e.g., adjusting the `guidance_scale` for CFG), you have two options: **Option 1: Use ComponentSpec.create() method** + +You just need to pass the parameter with the new value to override the default one. + ```python >>> guider_spec = t2i_pipeline.get_component_spec("guider") >>> guider = guider_spec.create(guidance_scale=10) @@ -746,6 +755,9 @@ To change parameters of the same guider type (e.g., adjusting the `guidance_scal ``` **Option 2: Pass ComponentSpec directly** + +Update the spec directly and pass it to `update_components()`. + ```python >>> guider_spec = t2i_pipeline.get_component_spec("guider") >>> guider_spec.config["guidance_scale"] = 10 @@ -787,7 +799,6 @@ ModularPipeline.update_components: adding guider with new type: PerturbedAttenti -💡 **Component Loading Methods**: - For `from_config` components (like guiders, schedulers): You can pass an object of required type OR pass a ComponentSpec directly (which calls `create()` under the hood) - For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading @@ -831,50 +842,25 @@ The component spec has also been updated to reflect the new guider type: ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') ``` -However, the "guider" is still not included in the pipeline config and will not be saved into the `modular_model_index.json` since it remains a `from_config` component: +The "guider" is still a `from_config` component: is still not included in the pipeline config and will not be saved into the `modular_model_index.json`. ```py >>> assert "guider" not in t2i_pipeline.config ``` -#### Upload Custom Guider to Hub for Easy Loading & Sharing - -You can upload your customized guider to the Hub so that it can be loaded more easily: - -```py -guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider") -``` - -Voilà! Now you have a subfolder called `pag_guider` on that repository. Let's change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created: - -```python -guider_spec = t2i_pipeline.get_component_spec("guider") -guider_spec.default_creation_method="from_pretrained" -guider_spec.repo="YiYiXu/modular-loader-t2i-guider" -guider_spec.subfolder="pag_guider" -pag_guider = guider_spec.load() -t2i_pipeline.update_components(guider=pag_guider) -``` - -You will get a warning about changing the creation method: - -``` -ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained. -``` +However, you can change it to a `from_pretrained` component, which allows you to upload your customized guider to the Hub and load it into your pipeline. -Now not only the `guider` component and its component_spec are updated, but so is the pipeline config. Let's push it to a new repository: +#### Loading Custom Guiders from Hub -```py -t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider") -``` +If you already have a guider saved on the Hub and a `modular_model_index.json` with the loading spec for that guider, it will automatically be changed to a `from_pretrained` component during pipeline initialization. -If you check the `modular_model_index.json`, you'll see the guider is now included: +For example, this `modular_model_index.json` includes loading specs for the guider: ```json { "guider": [ - "diffusers", - "PerturbedAttentionGuidance", + null, + null, { "repo": "YiYiXu/modular-loader-t2i-guider", "revision": null, @@ -889,16 +875,55 @@ If you check the `modular_model_index.json`, you'll see the guider is now includ } ``` -Now when you create the pipeline from that repo directly, the `guider` is not automatically loaded anymore (since it's now a `from_pretrained` component), but when you run `load_default_components()`, the PAG guider will be loaded by default: +When you use this repository to create a pipeline with the same blocks (that originally configured guider as a `from_config` component), the guider becomes a `from_pretrained` component. This means it doesn't get created during initialization, and after you call `load_default_components()`, it loads based on the spec - resulting in the PAG guider instead of the default CFG. ```py t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider") -assert t2i_pipeline.guider is None +assert t2i_pipeline.guider is None # Not created during init t2i_pipeline.load_default_components() -t2i_pipeline.guider +t2i_pipeline.guider # Now loaded as PAG guider +``` + +#### Upload Custom Guider to Hub for Easy Loading & Sharing + +Now let's see how we can share the guider on the Hub and change it to a `from_pretrained` component. + +```py +guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider") +``` + +Voilà! Now you have a subfolder called `pag_guider` on that repository. + +You have a few options to make this guider available in your pipeline: + +1. **Directly modify the `modular_model_index.json`** to add a loading spec for the guider by pointing to a folder containing the desired guider config. + +2. **Use the `update_components` method** to change it to a `from_pretrained` component for your pipeline. This is easier if you just want to try it out with different repositories. + +Let's use the second approach and change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created: + +```python +guider_spec = t2i_pipeline.get_component_spec("guider") +guider_spec.default_creation_method="from_pretrained" +guider_spec.repo="YiYiXu/modular-loader-t2i-guider" +guider_spec.subfolder="pag_guider" +pag_guider = guider_spec.load() +t2i_pipeline.update_components(guider=pag_guider) ``` -Of course, you can also directly modify the `modular_model_index.json` to add a loading spec for the guider by pointing to a folder containing the desired guider config. +You will get a warning about changing the creation method: + +``` +ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained. +``` + +Now not only the `guider` component and its component_spec are updated, but so is the pipeline config. + +If you want to change the default behavior for future pipelines, you can push the updated pipeline to the Hub. This way, when others use your repository, they'll get the PAG guider by default. However, this is optional - you don't have to do this if you just want to experiment locally. + +```py +t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider") +``` diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 359fe5823dae..04194cc02163 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -28,15 +28,15 @@ With Modular Diffusers, we introduce a unified pipeline system that simplifies h Here's how our guides are organized to help you navigate the Modular Diffusers documentation: ### 🚀 Running Pipelines -- **[modular_pipeline.md](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it -- **[components_manager.md](./components_manager.md)** - How to manage and reuse components across multiple pipelines +- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it +- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines ### 📚 Creating PipelineBlocks -- **[modular_diffusers_states.md](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState -- **[pipeline_block.md](./pipeline_block.md)** - How to write custom PipelineBlocks -- **[sequential_pipeline_blocks.md](sequential_pipeline_blocks.md)** - Connecting blocks in sequence -- **[loop_sequential_pipeline_blocks.md](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows -- **[auto_pipeline_blocks.md](./auto_pipeline_blocks.md)** - Conditional block selection +- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState +- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks +- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence +- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows +- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection ### 🎯 Practical Examples -- **[end_to_end_guide.md](./end_to_end_guide.md)** - Complete end-to-end examples and practical workflows +- **[end_to_end_guide.md](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes From 65ba8928cca68fe52cce5563e0e6b71b5edce1a5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 10 Jul 2025 01:10:59 +0200 Subject: [PATCH 169/170] update doc --- .../modular_diffusers/components_manager.md | 8 ++--- .../en/modular_diffusers/end_to_end_guide.md | 10 +++--- .../loop_sequential_pipeline_blocks.md | 1 + docs/source/en/modular_diffusers/overview.md | 2 +- .../en/modular_diffusers/pipeline_block.md | 31 +++++++++++++------ 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md index 222944e83703..15b6c66b9b06 100644 --- a/docs/source/en/modular_diffusers/components_manager.md +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -144,9 +144,9 @@ Components: ====================================================================================================================================================================================================== Models: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ -Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ -text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A +text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ @@ -282,7 +282,7 @@ As mentioned earlier, `ModularPipeline` has a property `null_component_names` th The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates: -``` +```out ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400' ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424' ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504' @@ -293,7 +293,7 @@ ComponentsManager: component 'vae' already exists as 'vae_139917722459040' ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408' ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432' ``` -``` + The pipeline is now fully loaded: diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index 6a9e4dc31303..cb7b87552a37 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -361,9 +361,9 @@ Run the example now, you should see an apple with its right half transformed int ## Adding IP-adapter -We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks) +We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md) -We talked about how to add IP-adapter into your workflow in the [getting-started guide](https://huggingface.co/docs/diffusers/modular_diffusers/quicktour#ip-adapter). Let's just go ahead to create the IP-adapter block. +We talked about how to add IP-adapter into your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's just go ahead to create the IP-adapter block. ```py >>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep @@ -496,7 +496,7 @@ From looking at the code workflow: differential diffusion only modifies the "bef Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs. -With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseStep`: +With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step and controlnet denoiser step. ```py >>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): @@ -617,7 +617,7 @@ to use ``` ## Creating a Modular Repo -You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff +You can easily share your differential diffusion workflow on the Hub by creating a modular repo. This is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily. @@ -641,7 +641,7 @@ With a modular repo, it is very easy for the community to use the workflow you j >>> components.enable_auto_cpu_offload() ``` -see more usage example on model card +see more usage example on model card. ## deploy a mellon node diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md index e97a133d221a..e95cdc7163b4 100644 --- a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md +++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md @@ -39,6 +39,7 @@ class DenoiseLoop(PipelineBlock): pass self.set_block_state(state, block_state) return components, state +``` But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations. diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 04194cc02163..9702cea0633d 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -39,4 +39,4 @@ Here's how our guides are organized to help you navigate the Modular Diffusers d - **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection ### 🎯 Practical Examples -- **[end_to_end_guide.md](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes +- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md index 20f46e928c28..17a819732fd0 100644 --- a/docs/source/en/modular_diffusers/pipeline_block.md +++ b/docs/source/en/modular_diffusers/pipeline_block.md @@ -21,7 +21,9 @@ specific language governing permissions and limitations under the License. In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. + **Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md). + In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state. @@ -30,14 +32,14 @@ In this tutorial, we will focus on how to write a basic `PipelineBlock` and how Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details. -## Creating a `PipelineBlock` +## Define a `PipelineBlock` To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. The three main properties you need to define are: - `inputs`: Immutable values from the user that cannot be modified - `intermediate_inputs`: Mutable values from previous blocks that can be read and modified -- `intermediate_outputs`: New values your block creates for subsequent blocks +- `intermediate_outputs`: New values your block creates for subsequent blocks and user access Let's explore each one and understand how they work with the pipeline state. @@ -55,7 +57,7 @@ When you list something as an input, you're saying "I need this value directly f This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency. -**Intermediate Inputs: Mutable Values from Previous Blocks** +**Intermediate Inputs: Mutable Values from Previous Blocks, or Users** Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case: @@ -67,9 +69,12 @@ user_intermediate_inputs = [ When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different." -**Intermediate Outputs: New Values for Subsequent Blocks** +**Intermediate Outputs: New Values for Subsequent Blocks and User Access** + +Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes: -Intermediate outputs are new variables your block creates and adds to the mutable pipeline state so they can be used by subsequent blocks: +1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline +2. **For users**: They become available as final outputs that users can access when running the pipeline ```py user_intermediate_outputs = [ @@ -79,6 +84,8 @@ user_intermediate_outputs = [ Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match. +Additionally, all intermediate outputs are accessible to users when they run the pipeline, typically you would only need the final images, but they are also able to access intermediate results like latents, embeddings, or other processing steps. + **The `__call__` Method Structure** Your `PipelineBlock`'s `__call__` method should follow this structure: @@ -122,7 +129,15 @@ expected_config = [ **Configs**: Pipeline-level settings that control behavior across all blocks. -When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block in the `components` argument of the `__call__` method. +When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation: + +```py +def __call__(self, components, state): + # Access components using dot notation + unet = components.unet + vae = components.vae + scheduler = components.scheduler +``` That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks @@ -274,6 +289,4 @@ pipeline_state (after update): PipelineState( - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only. - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output - - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output - -Understanding how to create `PipelineBlock`s is fundamental to building modular workflows in Modular Diffusers. Remember that `PipelineBlock`s are definitions/specifications - they define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). \ No newline at end of file + - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output \ No newline at end of file From 01300a35f5391725087ae027faefab32f330250c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 10 Jul 2025 03:17:38 +0200 Subject: [PATCH 170/170] up --- .../stable_diffusion_xl/before_denoise.py | 46 +++++++++++++------ .../stable_diffusion_xl/decoders.py | 21 +++++---- .../stable_diffusion_xl/modular_pipeline.py | 3 +- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index b064a74cbfa0..c56f4af1b8a5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -19,8 +19,9 @@ import torch from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...image_processor import VaeImageProcessor -from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging @@ -266,37 +267,37 @@ def intermediate_outputs(self) -> List[str]: OutputParam( "prompt_embeds", type_hint=torch.Tensor, - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="text embeddings used to guide the image generation", ), OutputParam( "negative_prompt_embeds", type_hint=torch.Tensor, - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="negative text embeddings used to guide the image generation", ), OutputParam( "pooled_prompt_embeds", type_hint=torch.Tensor, - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="pooled text embeddings used to guide the image generation", ), OutputParam( "negative_pooled_prompt_embeds", type_hint=torch.Tensor, - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="negative pooled text embeddings used to guide the image generation", ), OutputParam( "ip_adapter_embeds", type_hint=List[torch.Tensor], - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="image embeddings for IP-Adapter", ), OutputParam( "negative_ip_adapter_embeds", type_hint=List[torch.Tensor], - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields description="negative image embeddings for IP-Adapter", ), ] @@ -683,12 +684,6 @@ def intermediate_outputs(self) -> List[str]: OutputParam( "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)", - ), OutputParam( "noise", type_hint=torch.Tensor, @@ -993,6 +988,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("vae", AutoencoderKL), ] @property @@ -1105,6 +1101,18 @@ def expected_configs(self) -> List[ConfigSpec]: ConfigSpec("requires_aesthetics_score", False), ] + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + @property def description(self) -> str: return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" @@ -1315,6 +1323,18 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): def description(self) -> str: return "Step that prepares the additional conditioning for the text-to-image generation process" + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 878e991dbf63..e9f627636e8c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -167,6 +167,17 @@ def description(self) -> str: + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" ) + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + @property def inputs(self) -> List[Tuple[str, Any]]: return [ @@ -190,16 +201,6 @@ def intermediate_inputs(self) -> List[str]: ), ] - @property - def intermediate_outputs(self) -> List[str]: - return [ - OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - description="The generated images with the mask overlayed", - ) - ] - @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index 0c45857da742..fc030fae56fb 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -91,7 +91,8 @@ def num_channels_latents(self): return num_channels_latents -# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks +# auto_docstring SDXL_INPUTS_SCHEMA = { "prompt": InputParam( "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"