From c8b5d5641271f88dc9c0ab41ca48e39ef143df3f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 2 May 2025 00:46:31 +0200 Subject: [PATCH 01/38] make loader optional --- src/diffusers/pipelines/modular_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 636b543395df..c994b91ba8bb 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -199,7 +199,8 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, state = PipelineState() if not hasattr(self, "loader"): - raise ValueError("Loader is not set, please call `setup_loader()` first.") + logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + self.loader = None # Make a copy of the input kwargs input_params = kwargs.copy() From 7b86fcea31d7c968e774dd16c275f601c2bed0fb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 2 May 2025 11:31:25 +0200 Subject: [PATCH 02/38] remove lora step and ip-adapter step -> no longer needed --- .../pipeline_stable_diffusion_xl_modular.py | 168 ------------------ 1 file changed, 168 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5ae9e63851db..0d068f90f7e6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -140,174 +140,6 @@ def retrieve_latents( -# YiYi Notes: I think we do not need this, we can add loader methods on the components class -class StableDiffusionXLLoraStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc" - " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("unet", UNet2DConditionModel), - ] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") - - -class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - 
ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
- ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device - - data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - pipeline, - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=data.prepare_unconditional_embeds, - ) - if data.prepare_unconditional_embeds: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, data) - return pipeline, state - - class StableDiffusionXLTextEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" From 7ca860c24bc35fccf5a68db2f92af932819f0b24 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 May 2025 01:32:59 +0200 Subject: [PATCH 03/38] rename pipeline -> components, data -> block_state --- .../pipeline_stable_diffusion_xl_modular.py | 1554 +++++++++-------- 1 file changed, 872 insertions(+), 682 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0d068f90f7e6..81808540ee67 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -65,6 +65,51 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. 
default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -140,6 +185,148 @@ def retrieve_latents( +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds 
in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + class StableDiffusionXLTextEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -189,15 +376,16 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), ] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(block_state): - if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") - elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + @staticmethod def encode_prompt( - self, components, prompt: str, prompt_2: Optional[str] = None, @@ -255,7 +443,7 @@ def encode_prompt( Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" - device = device or self._execution_device + device = device or components._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it @@ -433,42 +621,42 @@ def encode_prompt( @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(block_state) - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None ) ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, ) = self.encode_prompt( - pipeline, - data.prompt, - data.prompt_2, - data.device, + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, 1, - data.prepare_unconditional_embeds, - data.negative_prompt, - data.negative_prompt_2, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, ) # Add outputs - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLVaeEncoderStep(PipelineBlock): @@ -552,30 +740,30 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - data.batch_size = data.image.shape[0] + 
block_state.batch_size = block_state.image.shape[0] # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." ) - data.image_latents = self._encode_vae_image(pipeline,image=data.image, generator=data.generator) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): @@ -715,47 +903,47 @@ def prepare_mask_latents( @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + block_state = self.get_block_state(state) - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" else: - data.crops_coords = None - data.resize_mode = "default" + block_state.crops_coords = None + block_state.resize_mode = "default" - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = self._encode_vae_image(pipeline, 
image=data.image, generator=data.generator) + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInputStep(PipelineBlock): @@ -802,77 +990,77 @@ def intermediates_outputs(self) -> List[str]: OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), ] - def check_inputs(self, pipeline, data): + def check_inputs(self, components, block_state): - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." ) - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): raise ValueError("`negative_ip_adapter_embeds` must be a list") - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: raise ValueError( "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." + f" {block_state.negative_ip_adapter_embeds[i].shape}." ) @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype - _, seq_len, _ = data.prompt_embeds.shape + _, seq_len, _ = block_state.prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + 
block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @@ -961,40 +1149,40 @@ def get_timesteps(self, components, num_inference_steps, strength, device, denoi @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.device = pipeline._execution_device + block_state.device = components._execution_device - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas ) def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - data.timesteps, data.num_inference_steps = self.get_timesteps( - pipeline, - data.num_inference_steps, - data.strength, - data.device, - denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, ) - data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * 
block_state.num_images_per_prompt) - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLSetTimestepsStep(PipelineBlock): @@ -1029,27 +1217,27 @@ def intermediates_outputs(self) -> List[OutputParam]: @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.device = pipeline._execution_device + block_state.device = components._execution_device - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas ) - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) ) ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @@ -1133,7 +1321,46 @@ def intermediates_outputs(self) -> List[str]: OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", 
type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument def prepare_latents_inpaint( self, components, @@ -1252,58 +1479,58 @@ def prepare_mask_latents( @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device - data.is_strength_max = data.strength == 1.0 + block_state.is_strength_max = block_state.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: - data.masked_image_latents = None - - data.add_noise = True if data.denoising_start is None else False - - data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor - data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - - data.latents, data.noise = self.prepare_latents_inpaint( - pipeline, - data.batch_size * data.num_images_per_prompt, - pipeline.num_channels_latents, - data.height, 
- data.width, - data.dtype, - data.device, - data.generator, - data.latents, - image=data.image_latents, - timestep=data.latent_timestep, - is_strength_max=data.is_strength_max, - add_noise=data.add_noise, + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, return_noise=True, return_image_latents=False, ) # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image_latents, - data.batch_size * data.num_images_per_prompt, - data.height, - data.width, - data.dtype, - data.device, - data.generator, + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): @@ -1343,21 +1570,17 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components # YiYi TODO: refactor using _encode_vae_image + @staticmethod def prepare_latents_img2img( - self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None: - components.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -1431,28 +1654,28 @@ def prepare_latents_img2img( return latents @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = 
self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - data.add_noise = True if data.denoising_start is None else False - if data.latents is None: - data.latents = self.prepare_latents_img2img( - pipeline, - data.image_latents, - data.latent_timestep, - data.batch_size, - data.num_images_per_prompt, - data.dtype, - data.device, - data.generator, - data.add_noise, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = self.prepare_latents_img2img( + components, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @@ -1508,19 +1731,20 @@ def intermediates_outputs(self) -> List[OutputParam]: @staticmethod - def check_inputs(pipeline, data): + def check_inputs(components, block_state): if ( - data.height is not None - and data.height % pipeline.vae_scale_factor != 0 - or data.width is not None - and data.width % pipeline.vae_scale_factor != 0 + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 ): raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." 
) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, @@ -1544,34 +1768,34 @@ def prepare_latents(self, components, batch_size, num_channels_latents, height, @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.dtype is None: - data.dtype = pipeline.vae.dtype - - data.device = pipeline._execution_device - - self.check_inputs(pipeline, data) - - data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor - data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor - data.num_channels_latents = pipeline.num_channels_latents - data.latents = self.prepare_latents( - pipeline, - data.batch_size * data.num_images_per_prompt, - data.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, ) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @@ -1617,8 +1841,8 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod def _get_add_time_ids_img2img( - self, components, original_size, crops_coords_top_left, @@ -1670,8 +1894,9 @@ def _get_add_time_ids_img2img( return add_time_ids, add_neg_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -1701,57 +1926,57 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def 
__call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - if data.negative_original_size is None: - data.negative_original_size = data.original_size - if data.negative_target_size is None: - data.negative_target_size = data.target_size - - data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.aesthetic_score, - data.negative_aesthetic_score, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - dtype=data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None + block_state.timestep_cond = None if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None + 
hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): @@ -1805,8 +2030,9 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod def _get_add_time_ids( - self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -1824,8 +2050,9 @@ def _get_add_time_ids( return add_time_ids # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 @@ -1855,57 +2082,57 @@ def get_guidance_scale_embedding( return emb @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - data.add_time_ids = self._get_add_time_ids( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = 
block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) - if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = self._get_add_time_ids( - pipeline, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, ) else: - data.negative_add_time_ids = data.add_time_ids + block_state.negative_add_time_ids = block_state.add_time_ids - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None + block_state.timestep_cond = None if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None ): # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
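# A minimal, self-contained sketch of the sinusoidal guidance-scale embedding used as `timestep_cond`
# for LCM-style UNets (the `get_guidance_scale_embedding` helper above returns this kind of tensor,
# its body being mostly elided between the hunks). Only torch/math are assumed; the function name is
# illustrative and not part of the patch.
import math

import torch

def _guidance_scale_embedding_sketch(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # w holds (guidance_scale - 1) per sample; scale by 1000 and embed with sin/cos frequencies
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000.0) / (half_dim - 1)))
    emb = (w.float() * 1000.0)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad so the output width matches embedding_dim
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape (batch, embedding_dim); consumed by UNets with time_cond_proj_dim set

# Example: _guidance_scale_embedding_sketch(torch.tensor([6.5]), 256) for guidance_scale = 7.5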
- data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, data) - return pipeline, state + self.add_block_state(state, block_state) + return components, state class StableDiffusionXLDenoiseStep(PipelineBlock): @@ -2041,27 +2268,29 @@ def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + " `components.unet` or your `mask_image` or `image` input." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2079,42 +2308,42 @@ def prepare_extra_step_kwargs(self, components, generator, eta): return extra_step_kwargs @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.num_channels_unet = components.unet.config.in_channels + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2125,45 +2354,45 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds # Predict the noise residual - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - 
timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) + block_state.latents = block_state.latents.to(block_state.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @@ -2308,30 +2537,31 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and 
masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + " `components.unet` or your `mask_image` or `image` input." ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() + @staticmethod def prepare_control_image( - self, components, image, width, @@ -2359,7 +2589,8 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2378,108 +2609,108 @@ def prepare_extra_step_kwargs(self, components, generator, eta): @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels + block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(pipeline.controlnet) + controlnet = unwrap_module(components.controlnet) # (1.1) # control_guidance_start/control_guidance_end (align format) - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - data.control_guidance_start, data.control_guidance_end = ( - mult * [data.control_guidance_start], - mult * [data.control_guidance_end], + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], ) # (1.2) # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): - data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) # (1.3) # global_pool_conditions - data.global_pool_conditions = ( + block_state.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) # 
(1.4) # guess_mode - data.guess_mode = data.guess_mode or data.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.5) # control_image if isinstance(controlnet, ControlNetModel): - data.control_image = self.prepare_control_image( - pipeline, - image=data.control_image, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] - for control_image_ in data.control_image: + for control_image_ in block_state.control_image: control_image = self.prepare_control_image( - pipeline, + components, image=control_image_, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) control_images.append(control_image) - data.control_image = control_images + block_state.control_image = control_images else: assert False # (1.6) # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): keeps = [ - 1.0 - float(i / len(data.timesteps) < s or (i + 1) / len(data.timesteps) > e) - for s, e in zip(data.control_guidance_start, data.control_guidance_end) + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] - data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # (2) Prepare conditional inputs for unet using the guider - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), @@ -2487,23 +2718,23 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) # (5) Denoise loop - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(block_state.controlnet_cond_scale, list): + block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] + block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2520,70 +2751,70 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, + if components.guider.is_conditional or not block_state.guess_mode: + block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, + controlnet_cond=block_state.control_image, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, added_cond_kwargs=batch.controlnet_added_cond_kwargs, 
return_dict=False, ) - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample + batch.down_block_res_samples = block_state.down_block_res_samples + batch.mid_block_res_sample = block_state.mid_block_res_sample - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + if components.guider.is_unconditional and block_state.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, down_block_additional_residuals=batch.down_block_res_samples, mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) + block_state.latents = block_state.latents.to(block_state.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): @@ -2731,31 +2962,32 @@ def intermediates_inputs(self) -> List[str]: def intermediates_outputs(self) -> List[str]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - def check_inputs(self, pipeline, data): + @staticmethod + def check_inputs(components, block_state): - num_channels_unet = pipeline.unet.config.in_channels + num_channels_unet = components.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: + if block_state.mask is None or block_state.masked_image_latents is None: raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." 
+ " `components.unet` or your `mask_image` or `image` input." ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() + @staticmethod def prepare_control_image( - self, components, image, width, @@ -2785,7 +3017,8 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -2803,118 +3036,118 @@ def prepare_extra_step_kwargs(self, components, generator, eta): return extra_step_kwargs @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) - data.num_channels_unet = pipeline.unet.config.in_channels + block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(pipeline.controlnet) + controlnet = unwrap_module(components.controlnet) # (1.1) # control guidance - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] # (1.2) # global_pool_conditions & guess_mode - data.global_pool_conditions = controlnet.config.global_pool_conditions - data.guess_mode = data.guess_mode or data.global_pool_conditions + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.3) # control_type - data.num_control_type = controlnet.config.num_control_type + block_state.num_control_type = controlnet.config.num_control_type # 
(1.4) # control_type - if not isinstance(data.control_image, list): - data.control_image = [data.control_image] + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] - if not isinstance(data.control_mode, list): - data.control_mode = [data.control_mode] + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] - if len(data.control_image) != len(data.control_mode): + if len(block_state.control_image) != len(block_state.control_mode): raise ValueError("Expected len(control_image) == len(control_type)") - data.control_type = [0 for _ in range(data.num_control_type)] - for control_idx in data.control_mode: - data.control_type[control_idx] = 1 + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 - data.control_type = torch.Tensor(data.control_type) + block_state.control_type = torch.Tensor(block_state.control_type) # (1.5) # prepare control_image - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = self.prepare_control_image( - pipeline, - image=data.control_image[idx], - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, dtype=controlnet.dtype, - crops_coords=data.crops_coords, + crops_coords=block_state.crops_coords, ) - data.height, data.width = data.control_image[idx].shape[-2:] + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] # (1.6) # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - data.controlnet_keep.append( + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( 1.0 - - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) ) # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() else: - pipeline.guider.enable() + components.guider.enable() - data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) - repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] - data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) + block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype) + repeat_by = block_state.batch_size * 
block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - pipeline.guider.set_input_fields( + components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_data = components.guider.prepare_inputs(block_state) - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] + block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(block_state.controlnet_cond_scale, list): + block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] + block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) + components.guider.prepare_models(components.unet) # Prepare additional conditionings batch.added_cond_kwargs = { @@ -2931,70 +3164,70 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, + if components.guider.is_conditional or not block_state.guess_mode: + block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - 
conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, + controlnet_cond=block_state.control_image, + control_type=block_state.control_type, + control_type_idx=block_state.control_mode, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample + batch.down_block_res_samples = block_state.down_block_res_samples + batch.mid_block_res_sample = block_state.mid_block_res_sample - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + if components.guider.is_unconditional and block_state.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - batch.noise_pred = pipeline.unet( - data.scaled_latents, + batch.noise_pred = components.unet( + block_state.scaled_latents, t, encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, added_cond_kwargs=batch.added_cond_kwargs, down_block_additional_residuals=batch.down_block_res_samples, mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - pipeline.guider.cleanup_models(pipeline.unet) + components.guider.cleanup_models(components.unet) # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 9 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + if block_state.num_channels_unet == 9 and block_state.mask is not None and block_state.image_latents is not None: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): progress_bar.update() - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLDecodeLatentsStep(PipelineBlock): @@ -3031,7 +3264,8 @@ def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - def upcast_vae(self, components): + @staticmethod + def upcast_vae(components): dtype = components.vae.dtype components.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( @@ -3049,57 +3283,57 @@ def upcast_vae(self, components): components.vae.decoder.mid_block.to(dtype) @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if not data.output_type == "latent": + if not block_state.output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - if data.needs_upcasting: - self.upcast_vae(pipeline) - data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif data.latents.dtype != pipeline.vae.dtype: + if block_state.needs_upcasting: + self.upcast_vae(components) + block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif block_state.latents.dtype != components.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(data.latents.dtype) + components.vae = components.vae.to(block_state.latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None - data.has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None ) - data.has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None ) - if data.has_latents_mean and data.has_latents_std: - data.latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) ) - data.latents_std = ( - torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) ) - data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean + block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean else: - data.latents = data.latents / pipeline.vae.config.scaling_factor + block_state.latents = block_state.latents / components.vae.config.scaling_factor - data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] # cast back to fp16 if needed - if data.needs_upcasting: - pipeline.vae.to(dtype=torch.float16) + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) else: - data.images = data.latents + block_state.images = block_state.latents # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - data.images = pipeline.watermark.apply_watermark(data.images) + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) - data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): @@ -3130,15 +3364,15 @@ def intermediates_outputs(self) -> List[str]: return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if 
data.padding_mask_crop is not None and data.crops_coords is not None: - data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images] + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] - self.add_block_state(state, data) + self.add_block_state(state, block_state) - return pipeline, state + return components, state class StableDiffusionXLOutputStep(PipelineBlock): @@ -3162,15 +3396,15 @@ def intermediates_outputs(self) -> List[str]: @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - if not data.return_dict: - data.images = (data.images,) + if not block_state.return_dict: + block_state.images = (block_state.images,) else: - data.images = StableDiffusionXLPipelineOutput(images=data.images) - self.add_block_state(state, data) - return pipeline, state + block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) + self.add_block_state(state, block_state) + return components, state # Encode @@ -3400,50 +3634,6 @@ def description(self): } -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? 
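# A compact sketch of the "container with derived defaults" idea listed in the notes above: the loader
# holds components and exposes config-derived properties with safe fallbacks. Names here are
# illustrative; the actual class is the StableDiffusionXLModularLoader defined in this file.
class _LoaderDefaultsSketch:
    vae = None  # an AutoencoderKL once loaded, None before any checkpoint is attached

    @property
    def vae_scale_factor(self) -> int:
        # SDXL VAEs downsample by 2 per block; fall back to 8 when no VAE is registered yet
        if self.vae is not None:
            return 2 ** (len(self.vae.config.block_out_channels) - 1)
        return 8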
-class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - - # YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks SDXL_INPUTS_SCHEMA = { From efd70b783871aa7b3e02bd8252afbc8e45eeb314 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 May 2025 20:22:05 +0200 Subject: [PATCH 04/38] seperate controlnet step into input + denoise --- .../pipeline_stable_diffusion_xl_modular.py | 466 +++++++++++------- 1 file changed, 299 insertions(+), 167 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 81808540ee67..ea774283437a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2395,27 +2395,20 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): +class StableDiffusionXLControlNetInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step that prepare inputs for controlnet" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -2426,9 +2419,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), ] @property @@ -2452,110 +2442,25 @@ def intermediates_inputs(self) -> List[str]: type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), InputParam( "crops_coords", type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
- ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + return [ + OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] - @staticmethod - def check_inputs(components, block_state): - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance @@ -2588,33 +2493,12 @@ def prepare_control_image( image = image.to(device=device, dtype=dtype) return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels # (1) prepare controlnet inputs block_state.device = components._execution_device @@ -2699,17 +2583,243 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - # (2) Prepare conditional inputs for unet using the guider + + + self.add_block_state(state, block_state) + + return components, state + +class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam("num_images_per_prompt", default=1), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "control_image", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_start", + required=True, + type_hint=float, + description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_end", + required=True, + type_hint=float, + description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_conditioning_scale", + required=True, + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+ ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "add_time_ids", + required=True, + type_hint=torch.Tensor, + description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], + description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], + description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
+ ), + InputParam( + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], + description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components + @staticmethod + def prepare_extra_step_kwargs(components, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + block_state.device = components._execution_device + + # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + # (1) setup guider + # disable for LCMs block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False if block_state.disable_guidance: components.guider.disable() else: components.guider.enable() - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), @@ -2720,11 +2830,16 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # (5) Denoise loop with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) + # prepare latent input for unet block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + # adjust latent input for inpainting + block_state.num_channels_unet = components.unet.config.in_channels + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + # cond_scale (controlnet input) if isinstance(block_state.controlnet_keep[i], list): block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: @@ -2733,62 +2848,69 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - for batch in guider_data: + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(block_state) + + for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, + guider_state_batch.added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + if guider_state_batch.ip_adapter_embeds is not None: + guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": 
batch.add_time_ids, + guider_state_batch.controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - # Will always be run atleast once with every guider - if components.guider.is_conditional or not block_state.guess_mode: - block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros + guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, controlnet_cond=block_state.control_image, conditioning_scale=block_state.cond_scale, guess_mode=block_state.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, + added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = block_state.down_block_res_samples - batch.mid_block_res_sample = block_state.mid_block_res_sample + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - if components.guider.is_unconditional and block_state.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - batch.noise_pred = components.unet( + + guider_state_batch.noise_pred = components.unet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, timestep_cond=block_state.timestep_cond, cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, + added_cond_kwargs=guider_state_batch.added_cond_kwargs, + down_block_additional_residuals=guider_state_batch.down_block_res_samples, + mid_block_additional_residual=guider_state_batch.mid_block_res_sample, return_dict=False, )[0] components.guider.cleanup_models(components.unet) # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype @@ -2799,6 +2921,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 block_state.latents = block_state.latents.to(block_state.latents_dtype) + # adjust latent for inpainting if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: block_state.init_latents_proper = block_state.image_latents if i < len(block_state.timesteps) - 1: @@ -3463,6 +3586,16 @@ def description(self): " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" +class StableDiffusionXLControlNetStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "Controlnet step that denoise the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] @@ -3477,10 +3610,9 @@ def description(self): " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." - # Denoise class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_union", "controlnet", "unet"] block_trigger_inputs = ["control_mode", "control_image", None] @@ -3489,7 +3621,7 @@ def description(self): return "Denoise step that denoise the latents.\n" + \ "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \ + " - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \ " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." 
# After denoise @@ -3597,7 +3729,7 @@ def description(self): ]) CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), + ("denoise", StableDiffusionXLControlNetStep), ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ From 43ac1ff7e78ffdf8fa91932769236a7995ac482e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 4 May 2025 22:17:25 +0200 Subject: [PATCH 05/38] refactor controlnet union --- .../pipeline_stable_diffusion_xl_modular.py | 426 ++++++++++++------ 1 file changed, 284 insertions(+), 142 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index ea774283437a..5ebdd383ccbb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2613,12 +2613,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), InputParam("generator"), @@ -2755,6 +2749,12 @@ def intermediates_inputs(self) -> List[str]: type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), ] @property @@ -2940,25 +2940,198 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", default=[0]), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. 
Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_mode", type_hint=List[int], description="The control mode indices"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = block_state.device or components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + 
block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + + + self.add_block_state(state, block_state) + + return components, state + class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec( "guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), - ComponentSpec( - "control_image_processor", - VaeImageProcessor, - config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), - default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetUnionModel), ] @property @@ -2967,12 +3140,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("control_mode", required=True), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), InputParam("generator"), @@ -2983,15 +3150,75 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ 
InputParam( - "latents", + "control_image", + required=True, + type_hint=List[torch.Tensor], + description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_mode", + required=True, + type_hint=List[int], + description="The control mode indices. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_type", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step." ), InputParam( - "batch_size", + "num_control_type", required=True, type_hint=int, + description="The number of control types available. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_guidance_start", + required=True, + type_hint=float, + description="The control guidance start value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "control_guidance_end", + required=True, + type_hint=float, + description="The control guidance end value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "controlnet_conditioning_scale", + required=True, + type_hint=float, + description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="Whether guess mode is used. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "global_pool_conditions", + required=True, + type_hint=bool, + description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values. Can be generated in prepare controlnet inputs step." + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( @@ -3045,23 +3272,23 @@ def intermediates_inputs(self) -> List[str]: description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." 
), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( @@ -3070,19 +3297,19 @@ def intermediates_inputs(self) -> List[str]: description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), ] @property - def intermediates_outputs(self) -> List[str]: + def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] @staticmethod @@ -3105,39 +3332,7 @@ def check_inputs(components, block_state): " `components.unet` or your `mask_image` or `image` input." ) - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @staticmethod @@ -3164,85 +3359,20 @@ def __call__(self, components, state: PipelineState) -> PipelineState: self.check_inputs(components, block_state) block_state.num_channels_unet = components.unet.config.in_channels - - # (1) prepare controlnet inputs block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control guidance - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not 
isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - - # (1.2) - # global_pool_conditions & guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.3) - # control_type - block_state.num_control_type = controlnet.config.num_control_type - - # (1.4) - # control_type - if not isinstance(block_state.control_image, list): - block_state.control_image = [block_state.control_image] - - if not isinstance(block_state.control_mode, list): - block_state.control_mode = [block_state.control_mode] - - if len(block_state.control_image) != len(block_state.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - block_state.control_type = [0 for _ in range(block_state.num_control_type)] - for control_idx in block_state.control_mode: - block_state.control_type[control_idx] = 1 - - block_state.control_type = torch.Tensor(block_state.control_type) - # (1.5) - # prepare control_image - for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( - components, - image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - block_state.controlnet_keep.append( - 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) - ) + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + # Setup guider + # disable for LCMs block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False if block_state.disable_guidance: components.guider.disable() else: components.guider.enable() - block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype) - repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] - block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) @@ -3612,7 +3742,7 @@ def description(self): # Denoise class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] + block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] block_names = ["controlnet_union", "controlnet", "unet"] block_trigger_inputs = ["control_mode", "control_image", None] @@ -3620,8 +3750,8 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): def description(self): return "Denoise step that denoise the latents.\n" + \ "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \ + " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." # After denoise @@ -3733,7 +3863,7 @@ def description(self): ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), + ("denoise", StableDiffusionXLControlNetUnionStep), ]) IP_ADAPTER_BLOCKS = OrderedDict([ @@ -3865,3 +3995,15 @@ def description(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") } + + +class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "ControlNetUnion step that denoises the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model." 
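Note: with the two patches above, ControlNet conditioning is now expressed as an input-preparation block chained to a denoise block via SequentialPipelineBlocks (`StableDiffusionXLControlNetStep`, block_names ["prepare_input", "denoise"]; the ControlNetUnion variant follows the same pattern). The sketch below illustrates how such a composed step could be driven through a PipelineState. It is illustrative only: the import paths, the placeholder tensors and shapes, the `loader` object, and the assumption that a SequentialPipelineBlocks instance is called with `(components, state)` like its sub-blocks are assumptions made for this example, not something the patches themselves establish.

# Hedged sketch (not part of the patch series): wiring the split ControlNet blocks.
import torch
from diffusers.pipelines.modular_pipeline import PipelineState
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    StableDiffusionXLControlNetStep,
)

controlnet_step = StableDiffusionXLControlNetStep()  # sequential: "prepare_input" -> "denoise"

# Placeholder values; in a full pipeline these come from earlier blocks
# (text_encoder, prepare_latents, set_timesteps, ...). Shapes are illustrative.
control_image = torch.zeros(1, 3, 1024, 1024)        # conditioning image as a tensor
latents = torch.randn(1, 4, 128, 128)                # initial noise latents
timesteps = torch.linspace(999, 0, 30)               # scheduler timesteps
prompt_embeds = torch.randn(1, 77, 2048)             # text-encoder hidden states
pooled_prompt_embeds = torch.randn(1, 1280)          # pooled text embedding
add_time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])

state = PipelineState()
# user-facing inputs consumed by StableDiffusionXLControlNetInputStep
state.add_input("control_image", control_image)
state.add_input("controlnet_conditioning_scale", 0.5)
state.add_input("num_images_per_prompt", 1)
# intermediates that earlier blocks would normally have written into the state
state.add_intermediate("latents", latents)
state.add_intermediate("timesteps", timesteps)
state.add_intermediate("num_inference_steps", len(timesteps))
state.add_intermediate("batch_size", 1)
state.add_intermediate("prompt_embeds", prompt_embeds)
state.add_intermediate("negative_prompt_embeds", torch.randn(1, 77, 2048))
state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds)
state.add_intermediate("negative_pooled_prompt_embeds", torch.randn(1, 1280))
state.add_intermediate("add_time_ids", add_time_ids)
state.add_intermediate("negative_add_time_ids", add_time_ids.clone())

# `loader` stands in for a StableDiffusionXLModularLoader holding unet, controlnet,
# scheduler and guider. The input block writes control_image / controlnet_keep /
# guess_mode etc. back into the state; the denoise block reads them and returns
# the denoised latents.
components, state = controlnet_step(loader, state)
denoised_latents = state.get_intermediate("latents")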
From dc4dbfe10711f4f4e70c435a996cfedec00e5218 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 6 May 2025 09:58:44 +0200
Subject: [PATCH 06/38] refactor pipeline/block states so that it can dynamically accept kwargs

---
 src/diffusers/pipelines/modular_pipeline.py | 153 ++++++++++++++----
 .../pipelines/modular_pipeline_utils.py | 4 +-
 2 files changed, 127 insertions(+), 30 deletions(-)

diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py
index c994b91ba8bb..1733ad6d4e00 100644
--- a/src/diffusers/pipelines/modular_pipeline.py
+++ b/src/diffusers/pipelines/modular_pipeline.py
@@ -73,18 +73,72 @@ class PipelineState:
     inputs: Dict[str, Any] = field(default_factory=dict)
     intermediates: Dict[str, Any] = field(default_factory=dict)
+    input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+    intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)

-    def add_input(self, key: str, value: Any):
+    def add_input(self, key: str, value: Any, kwargs_type: str = None):
+        """
+        Add an input to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the input
+            value (Any): The input value
+            kwargs_type (str): The kwargs_type to store with the input
+        """
         self.inputs[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.input_kwargs:
+                self.input_kwargs[kwargs_type] = [key]
+            else:
+                self.input_kwargs[kwargs_type].append(key)

-    def add_intermediate(self, key: str, value: Any):
+    def add_intermediate(self, key: str, value: Any, kwargs_type: str = None):
+        """
+        Add an intermediate value to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the intermediate value
+            value (Any): The intermediate value
+            kwargs_type (str): The kwargs_type to store with the intermediate value
+        """
         self.intermediates[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.intermediate_kwargs:
+                self.intermediate_kwargs[kwargs_type] = [key]
+            else:
+                self.intermediate_kwargs[kwargs_type].append(key)

     def get_input(self, key: str, default: Any = None) -> Any:
         return self.inputs.get(key, default)

     def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
         return {key: self.inputs.get(key, default) for key in keys}
+
+    def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all inputs with matching kwargs_type.
+
+        Args:
+            kwargs_type (str): The kwargs_type to filter by
+
+        Returns:
+            Dict[str, Any]: Dictionary of inputs with matching kwargs_type
+        """
+        input_names = self.input_kwargs.get(kwargs_type, [])
+        return self.get_inputs(input_names)
+
+    def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all intermediates with matching kwargs_type.
+ + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) @@ -106,11 +160,17 @@ def format_value(v): inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) return ( f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" f")" ) @@ -146,10 +206,16 @@ def format_value(v): # Handle dicts with tensor values elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict # Default case return repr(v) @@ -203,30 +269,34 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, self.loader = None # Make a copy of the input kwargs - input_params = kwargs.copy() + passed_kwargs = kwargs.copy() - default_params = self.default_call_parameters # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) + state.add_input(name, passed_kwargs.pop(name), kwargs_type) else: - state.add_input(name, input_params[name]) + state.add_input(name, passed_kwargs[name], kwargs_type) elif name not in state.inputs: - state.add_input(name, default) + state.add_input(name, default, kwargs_type) - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. 
This input will be ignored.") + if len(passed_kwargs) > 0: + logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -390,25 +460,50 @@ def get_block_state(self, state: PipelineState) -> dict: # Check inputs for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v # Check intermediates for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index c8064a5215aa..f300f259f9eb 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -244,11 +244,12 @@ class ConfigSpec: @dataclass class InputParam: """Specification for an input parameter.""" - name: str + name: str = None type_hint: Any = None default: Any = None required: bool = False description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -260,6 +261,7 @@ class OutputParam: name: str type_hint: Any = None description: str = "" 
+ kwargs_type: str = None def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" From f552773572a9a27d80aa35910e45c26883583bc5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 6 May 2025 10:00:14 +0200 Subject: [PATCH 07/38] remove controlnet union denoise step, refactor & reuse controlnet denoise step to accept additional controlnet kwargs --- .../pipeline_stable_diffusion_xl_modular.py | 475 +++--------------- 1 file changed, 57 insertions(+), 418 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5ebdd383ccbb..119c92e06f1d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2452,11 +2452,11 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -2582,6 +2582,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale @@ -2615,15 +2618,16 @@ def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("num_images_per_prompt", default=1), InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), + InputParam("generator", kwargs_type="scheduler_kwargs"), + InputParam("eta", default=0.0, kwargs_type="scheduler_kwargs"), + InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can be either an input or an intermediate input (the intermediate is used if both are passed) ] @property def intermediates_inputs(self) -> List[str]: return [ InputParam( "controlnet_cond", required=True, type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( "control_guidance_start", required=True, type_hint=List[float], description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( "control_guidance_end", required=True, type_hint=List[float], description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
), InputParam( - "controlnet_conditioning_scale", - required=True, + "conditioning_scale", type_hint=float, description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), @@ -2755,6 +2758,7 @@ def intermediates_inputs(self) -> List[str]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), + InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") ] @property @@ -2780,26 +2784,16 @@ def check_inputs(components, block_state): f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `components.unet` or your `mask_image` or `image` input." ) - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs + return extra_kwargs @torch.no_grad() @@ -2808,9 +2802,14 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state = self.get_block_state(state) self.check_inputs(components, block_state) block_state.device = components._execution_device + + controlnet = unwrap_module(components.controlnet) # Prepare extra step kwargs.
TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + # YiYi TODO: refactor scheduler_kwargs and support unet kwargs + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) # (1) setup guider @@ -2841,9 +2841,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt # cond_scale (controlnet input) if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] else: - block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale + block_state.controlnet_cond_scale = block_state.conditioning_scale if isinstance(block_state.controlnet_cond_scale, list): block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] @@ -2882,11 +2882,12 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.scaled_latents, t, encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.control_image, + controlnet_cond=block_state.controlnet_cond, conditioning_scale=block_state.cond_scale, guess_mode=block_state.guess_mode, added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, return_dict=False, + **block_state.extra_controlnet_kwargs, ) if block_state.down_block_res_samples_zeros is None: @@ -2958,7 +2959,7 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("control_image", required=True), - InputParam("control_mode", default=[0]), + InputParam("control_mode", required=True), InputParam("control_guidance_start", default=0.0), InputParam("control_guidance_end", default=1.0), InputParam("controlnet_conditioning_scale", default=1.0), @@ -2973,7 +2974,7 @@ def intermediates_inputs(self) -> List[InputParam]: "latents", required=True, type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." ), InputParam( "batch_size", @@ -2991,7 +2992,7 @@ def intermediates_inputs(self) -> List[InputParam]: "timesteps", required=True, type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step."
), InputParam( "crops_coords", @@ -3003,13 +3004,13 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_mode", type_hint=List[int], description="The control mode indices"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"), + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images", kwargs_type="controlnet_kwargs"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -3051,7 +3052,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt controlnet = unwrap_module(components.controlnet) - device = block_state.device or components._execution_device + device = components._execution_device dtype = block_state.dtype or components.controlnet.dtype block_state.height, block_state.width = block_state.latents.shape[-2:] @@ -3069,10 +3070,10 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.global_pool_conditions = controlnet.config.global_pool_conditions block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - + # control_image if not isinstance(block_state.control_image, list): block_state.control_image = [block_state.control_image] - + # control_mode if not isinstance(block_state.control_mode, list): block_state.control_mode = [block_state.control_mode] @@ -3112,371 +3113,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt 1.0 - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) ) - - - self.add_block_state(state, block_state) - - return components, state - -class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetUnionModel), - ] - - @property - def description(self) -> str: - return " The denoising step for the controlnet union model, works 
for inpainting, image-to-image, and text-to-image tasks" - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "control_image", - required=True, - type_hint=List[torch.Tensor], - description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_mode", - required=True, - type_hint=List[int], - description="The control mode indices. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_type", - required=True, - type_hint=torch.Tensor, - description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "num_control_type", - required=True, - type_hint=int, - description="The number of control types available. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_guidance_start", - required=True, - type_hint=float, - description="The control guidance start value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "control_guidance_end", - required=True, - type_hint=float, - description="The control guidance end value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "controlnet_conditioning_scale", - required=True, - type_hint=float, - description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="Whether guess mode is used. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "global_pool_conditions", - required=True, - type_hint=bool, - description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values. Can be generated in prepare controlnet inputs step." - ), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. 
See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels - block_state.device = components._execution_device - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - # Setup guider - # disable for LCMs - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] - else: - block_state.controlnet_cond_scale = block_state.controlnet_conditioning_scale - if isinstance(block_state.controlnet_cond_scale, list): - block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] - block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - - for batch in guider_data: - components.guider.prepare_models(components.unet) - - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - - # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - - # Will always be run atleast once with every guider - if components.guider.is_conditional or not block_state.guess_mode: - block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=block_state.control_image, - control_type=block_state.control_type, - control_type_idx=block_state.control_mode, - conditioning_scale=block_state.cond_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, - return_dict=False, - ) - - batch.down_block_res_samples = block_state.down_block_res_samples - batch.mid_block_res_sample = block_state.mid_block_res_sample - - if components.guider.is_unconditional and block_state.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - 
mid_block_additional_residual=batch.mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - if block_state.num_channels_unet == 9 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale self.add_block_state(state, block_state) @@ -3727,6 +3366,18 @@ def description(self): " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." +class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["prepare_input", "denoise"] + + @property + def description(self): + return "ControlNetUnion step that denoises the latents.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ + " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents using the ControlNetUnion model." 
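A minimal sketch (illustrative only, not part of this patch) of how the kwargs_type grouping above is consumed: any intermediate produced with OutputParam(..., kwargs_type="controlnet_kwargs") is picked up by an unnamed InputParam(kwargs_type="controlnet_kwargs") and exposed on the block state both as individual attributes and as a dict, which is what lets the single StableDiffusionXLControlNetDenoiseStep serve both ControlNetModel and ControlNetUnionModel. The block and variable names below are hypothetical:

import inspect
import torch

class ExampleControlNetConsumerStep(PipelineBlock):
    # hypothetical block, shown only to illustrate the kwargs_type mechanism

    @property
    def intermediates_inputs(self):
        # no `name`: match every intermediate tagged with kwargs_type="controlnet_kwargs"
        return [InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet")]

    @torch.no_grad()
    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        # e.g. {"controlnet_cond": ..., "guess_mode": ..., "control_type": ..., "control_type_idx": ...}
        controlnet_kwargs = dict(block_state.controlnet_kwargs)
        # keep only what this controlnet's forward() actually accepts
        accepted = set(inspect.signature(components.controlnet.forward).parameters)
        controlnet_kwargs = {k: v for k, v in controlnet_kwargs.items() if k in accepted}
        # ... run the controlnet / unet with these kwargs, then write results back ...
        self.add_block_state(state, block_state)
        return components, state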
+ + class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] block_names = ["inpaint", "img2img", "text2img"] @@ -3995,15 +3646,3 @@ def description(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") } - - -class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "ControlNetUnion step that denoises the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model." From 16b6583fa8bcd0d5595984ac4c7f08c91ab3af3f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:25:31 +0200 Subject: [PATCH 08/38] allow input_fields as input & update message --- src/diffusers/guiders/adaptive_projected_guidance.py | 10 +++++++--- src/diffusers/guiders/auto_guidance.py | 10 +++++++--- src/diffusers/guiders/classifier_free_guidance.py | 10 +++++++--- .../guiders/classifier_free_zero_star_guidance.py | 10 +++++++--- src/diffusers/guiders/guider_utils.py | 4 ++-- src/diffusers/guiders/skip_layer_guidance.py | 10 +++++++--- src/diffusers/guiders/smoothed_energy_guidance.py | 10 +++++++--- .../guiders/tangential_classifier_free_guidance.py | 10 +++++++--- 8 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 7da1cc59a365..83e93c15ff1d 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,14 +73,18 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index bfffb9f39cd2..8bb6083781c2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -120,11 +120,15 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429f8450410a..429392e3f9c6 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -75,11 +75,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 4c9839ee78f3..220a95e54a8d 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,11 +73,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 7d005442e89c..18c85f579424 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -174,7 +174,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da from ..pipelines.modular_pipeline import BlockState if input_fields is None: - raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.") + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") data_batch = {} for key, value in input_fields.items(): try: @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + logger.warning(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bdd9e4af81b6..56dae1903606 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 1c7ee45dc3db..c215cb0afdc9 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -149,7 +149,11 @@ def cleanup_models(self, denoiser: torch.nn.Module): for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -161,7 +165,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 631f9a5f33b2..9fa8f9454134 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
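A short usage sketch for the new input_fields argument (illustrative only; the mapping shown mirrors the one the SDXL denoise blocks previously registered via set_input_fields). A caller can now pass the conditional/unconditional field mapping directly at call time instead of configuring the guider beforehand:

guider_input_fields = {
    "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
    "add_time_ids": ("add_time_ids", "negative_add_time_ids"),
    "pooled_prompt_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
    "ip_adapter_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
# each returned batch carries the cond or uncond tensors under the unified key names
guider_state = components.guider.prepare_inputs(block_state, input_fields=guider_input_fields)
# when input_fields is None, the guider falls back to the fields set via set_input_fields()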
import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -62,11 +62,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches From d89631fc50578dc5de0b95400b7d796daa8b0abc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:27:17 +0200 Subject: [PATCH 09/38] update input formatting, consider kwargs_type inputs with no name, e.g. *_controlnet_kwargs --- src/diffusers/pipelines/modular_pipeline_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index f300f259f9eb..392d6dcd9521 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -322,7 +322,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu if inp.name in required_intermediates_inputs: input_parts.append(f"Required({inp.name})") else: - input_parts.append(inp.name) + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} From 0f0618ff2b53397485fec8ca8c3fb698434019ef Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 May 2025 11:28:52 +0200 Subject: [PATCH 10/38] refactor the denoise step using LoopSequential! also add a new file for denoise step --- src/diffusers/pipelines/modular_pipeline.py | 288 +++- .../pipeline_stable_diffusion_xl_modular.py | 1185 +++++++++-------- ...table_diffusion_xl_modular_denoise_loop.py | 729 ++++++++++ 3 files changed, 1607 insertions(+), 595 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 1733ad6d4e00..92cb50a8b490 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -184,6 +184,23 @@ def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary.
+ + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return {key: value for key, value in self.__dict__.items()} + def __repr__(self): def format_value(v): # Handle tensors directly @@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li for block_name, inputs in named_input_lists: for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] if (current_param.default is not None and input_param.default is not None and current_param.default != input_param.default): @@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> for block_name, outputs in named_output_lists: for output_param in outputs: - if output_param.name not in combined_dict: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): combined_dict[output_param.name] = output_param return list(combined_dict.values()) @@ -919,6 +940,9 @@ def required_intermediates_inputs(self) -> List[str]: # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks @@ -931,6 +955,9 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): inputs = [] outputs = set() @@ -1169,7 +1196,262 @@ def doc(self): expected_configs=self.expected_configs ) +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. 
Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + 
required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif 
value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 119c92e06f1d..7869e11a9cd5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -370,10 +370,10 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), ] @staticmethod @@ -982,12 +982,12 @@ def intermediates_outputs(self) -> List[str]: return [ OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - 
OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), ] def check_inputs(self, components, block_state): @@ -1836,8 +1836,8 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -2025,8 +2025,8 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with 
self -> components @@ -2135,264 +2135,265 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -class StableDiffusionXLDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. 
Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - @staticmethod - def prepare_extra_step_kwargs(components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs +from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." 
+# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." 
+# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. +# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - for batch in guider_data: - components.guider.prepare_models(components.unet) +# for batch in guider_data: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - # Predict the noise residual - batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetInputStep(PipelineBlock): @@ -2452,11 +2453,11 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image", kwargs_type="contronet_kwargs"), + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] @@ -2592,353 +2593,353 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return 
components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator", kwargs_type="scheduler_kwargs"), - InputParam("eta", default=0.0, kwargs_type="scheduler_kwargs"), - InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "control_guidance_start", - required=True, - type_hint=float, - description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "control_guidance_end", - required=True, - type_hint=float, - description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." 
- ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
- ), - InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." - ) - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLControlNetDenoiseStep +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." 
+# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - block_state.device = components._execution_device - print(f" block_state: {block_state}") +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") - controlnet = unwrap_module(components.controlnet) +# controlnet = unwrap_module(components.controlnet) - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - # YiYI TODO: refactor scheduler_kwargs and support unet kwargs - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - # (1) setup guider - # disable for LCMs - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - components.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - # (5) Denoise loop - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - - # prepare latent input for unet - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - # adjust latent input for inpainting - block_state.num_channels_unet = components.unet.config.in_channels - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - - # cond_scale (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - block_state.controlnet_cond_scale = block_state.conditioning_scale - if 
isinstance(block_state.controlnet_cond_scale, list): - block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] - block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_state = components.guider.prepare_inputs(block_state) +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) - # Prepare additional conditionings - guider_state_batch.added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } - if guider_state_batch.ip_adapter_embeds is not None: - guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - # Prepare controlnet additional conditionings - 
guider_state_batch.controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.pooled_prompt_embeds, - "time_ids": guider_state_batch.add_time_ids, - } - - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros - guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.conditioning_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, - return_dict=False, - **block_state.extra_controlnet_kwargs, - ) +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=guider_state_batch.added_cond_kwargs, - down_block_additional_residuals=guider_state_batch.down_block_res_samples, - mid_block_additional_residual=guider_state_batch.mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# 
down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) - # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) - # adjust latent for inpainting - if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() - self.add_block_state(state, block_state) +# self.add_block_state(state, block_state) - return components, state +# return components, state class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): @@ -3004,13 +3005,13 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], 
description="The processed control images", kwargs_type="controlnet_kwargs"), + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used", kwargs_type="controlnet_kwargs"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), ] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py new file mode 100644 index 000000000000..92c07854fc74 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py @@ -0,0 +1,729 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from ...configuration_utils import FrozenDict
+from ...models import ControlNetModel, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+from ..modular_pipeline import (
+    PipelineBlock,
+    PipelineState,
+    LoopSequentialPipelineBlocks,
+    InputParam,
+    OutputParam,
+    BlockState,
+    ComponentSpec,
+)
+from ...guiders import ClassifierFreeGuidance
+from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader
+from dataclasses import asdict
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+
+# YiYi experimenting with a composable denoise loop
+# loop step (1): prepare latent input for denoiser
+class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser"
+
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")]
+
+
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+
+
+        return components, block_state
+
+# loop step (1): prepare latent input for denoiser (with inpainting)
+class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser (with inpainting support)"
+
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "masked_image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, loop_idx: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. 
" + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLDenoiseLoopControlNetDenoiserStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. 
Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. 
block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + 
guider_state_batch.noise_pred = components.unet(
+                block_state.scaled_latents,
+                t,
+                encoder_hidden_states=guider_state_batch.prompt_embeds,
+                timestep_cond=block_state.timestep_cond,
+                cross_attention_kwargs=block_state.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+                return_dict=False,
+            )[0]
+            components.guider.cleanup_models(components.unet)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
+# loop step (3): scheduler step to update latents
+class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("generator"),
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    #YiYi TODO: move this out of here
+    @staticmethod
+    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+
+        accepted_kwargs = set(inspect.signature(func).parameters.keys())
+        extra_kwargs = {}
+        for key, value in kwargs.items():
+            if key in accepted_kwargs and key not in exclude_kwargs:
+                extra_kwargs[key] = value
+
+        return extra_kwargs
+
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
+
+
+        # Perform scheduler step using the predicted output
+        block_state.latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
+
+        if block_state.latents.dtype != block_state.latents_dtype:
+            if torch.backends.mps.is_available():
+                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+        return components, block_state
+
+
+class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step and applies the inpainting mask, for inpainting task only"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("generator"),
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "noise",
+                type_hint=Optional[torch.Tensor],
+                description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    @staticmethod
+    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+
+        accepted_kwargs = set(inspect.signature(func).parameters.keys())
+        extra_kwargs = {}
+        for key, value in kwargs.items():
+            if key in accepted_kwargs and key not in exclude_kwargs:
+                extra_kwargs[key] = value
+
+        return extra_kwargs
+
+    def check_inputs(self, components, block_state):
+        if components.num_channels_unet == 4:
+            if block_state.image_latents is None:
+                raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
+            if block_state.mask is None:
+                raise ValueError(f"mask is required for this step {self.__class__.__name__}")
+            if block_state.noise is None:
+                raise ValueError(f"noise is required for this step {self.__class__.__name__}")
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        self.check_inputs(components, block_state)
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
+
+
+        # Perform scheduler step using the predicted output
+        block_state.latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
+
+        if block_state.latents.dtype != block_state.latents_dtype:
+            if torch.backends.mps.is_available():
+                # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
+ ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + + +# StableDiffusionXLControlNetDenoiseStep + +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): + block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] + block_names = ["prepare_latents", "denoiser", "update_latents"] + + + From c677d528e4c1c33b6c73c549c7a5f74ab8635f5e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:16:24 +0200 Subject: [PATCH 11/38] change warning to debug --- src/diffusers/guiders/guider_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 18c85f579424..df544c955f33 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - logger.warning(f"`data` does not have attribute(s) {value}, skipping.") + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) From 2b361a24132045786b229c1a6bfc3be0bd79e8a1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:17:10 +0200 Subject: [PATCH 12/38] fix get_execusion blocks with loopsequential --- src/diffusers/pipelines/modular_pipeline.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py 
b/src/diffusers/pipelines/modular_pipeline.py index 92cb50a8b490..97a8677bda63 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1033,16 +1033,17 @@ def trigger_inputs(self): def _traverse_trigger_blocks(self, trigger_inputs): # Convert trigger_inputs to a set for easier manipulation active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - # sequential or PipelineBlock + # sequential(include loopsequential) or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): - # sequential - for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} result_blocks.update(blocks_to_update) else: # PipelineBlock @@ -1069,13 +1070,14 @@ def fn_recursive_traverse(block, block_name, active_triggers): matching_trigger = None if this_block is not None: - # sequential/auto + # sequential/auto (keep traversing) if hasattr(this_block, 'blocks'): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? if hasattr(this_block, 'outputs'): active_triggers.update(out.name for out in this_block.outputs) From 2017ae56244f87fb2137888cb440afb1c7a87663 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 9 May 2025 08:19:24 +0200 Subject: [PATCH 13/38] fix auto denoise so all tests pass --- .../pipeline_stable_diffusion_xl_modular.py | 699 +----------------- ...table_diffusion_xl_modular_denoise_loop.py | 688 ++++++++++++++++- 2 files changed, 702 insertions(+), 685 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 7869e11a9cd5..acb395345086 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2134,268 +2134,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt self.add_block_state(state, block_state) return components, state - -from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# 
def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. 
Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. -# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - class StableDiffusionXLControlNetInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -2593,355 +2331,6 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLControlNetDenoiseStep -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. 
Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
-# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -3123,6 +2512,13 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + class StableDiffusionXLDecodeLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -3316,8 +2712,8 @@ def description(self): # Before denoise class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3326,12 +2722,13 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", 
"set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3340,12 +2737,13 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] @property def description(self): @@ -3354,29 +2752,8 @@ def description(self): " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - -class StableDiffusionXLControlNetStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "Controlnet step that denoise the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents." - -class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetDenoiseStep] - block_names = ["prepare_input", "denoise"] - - @property - def description(self): - return "ControlNetUnion step that denoises the latents.\n" + \ - "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents using the ControlNetUnion model." 
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): @@ -3387,24 +2764,27 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): @property def description(self): return "Before denoise step that prepare the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided." + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." -# Denoise -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_union", "controlnet", "unet"] - block_trigger_inputs = ["control_mode", "control_image", None] +# # Denoise +from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLAutoDenoiseStep +# class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): +# block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] +# block_names = ["controlnet_union", "controlnet", "unet"] +# block_trigger_inputs = ["control_mode", "control_image", None] - @property - def description(self): - return "Denoise step that denoise the latents.\n" + \ - "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ - " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." +# @property +# def description(self): +# return "Denoise step that denoise the latents.\n" + \ +# "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ +# " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ +# " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ +# " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." 
# After denoise class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): @@ -3474,6 +2854,7 @@ def description(self): # always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the # configuration of guider is. + # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -3511,11 +2892,13 @@ def description(self): ]) CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetStep), + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionStep), + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), ]) IP_ADAPTER_BLOCKS = OrderedDict([ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py index 92c07854fc74..63d0784a5762 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py @@ -22,10 +22,11 @@ from ...models import ControlNetModel, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging -from ...utils.torch_utils import unwrap_module +from ...utils.torch_utils import unwrap_module from ..modular_pipeline import ( PipelineBlock, PipelineState, + AutoPipelineBlocks, LoopSequentialPipelineBlocks, InputParam, OutputParam, @@ -42,7 +43,7 @@ # YiYi experimenting composible denoise loop # loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -83,7 +84,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock): +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -145,7 +146,7 @@ def check_inputs(components, block_state): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, loop_idx: int, t: int): + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) @@ -157,7 +158,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -267,7 +268,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state # loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLDenoiseLoopControlNetDenoiserStep(PipelineBlock): +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -468,7 +469,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc 
return components, block_state # loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock): +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -535,8 +536,8 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc return components, block_state - -class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock): +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): model_name = "stable-diffusion-xl" @@ -643,7 +644,7 @@ def __call__(self, components: StableDiffusionXLModularLoader, block_state: Bloc # the loop wrapper that iterates over the timesteps -class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks): +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): model_name = "stable-diffusion-xl" @@ -706,24 +707,657 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt return components, state +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> 
List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if 
torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
+# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." 
+# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") -# StableDiffusionXLControlNetDenoiseStep - -class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] - -class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop): - block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep] - block_names = ["prepare_latents", "denoiser", "update_latents"] +# controlnet = unwrap_module(components.controlnet) +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) +# return components, state \ No newline at end of file From cf01aaeb49a2632458113f4572dd3929426bd009 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:49:30 +0200 Subject: [PATCH 14/38] update imports on guiders --- src/diffusers/guiders/adaptive_projected_guidance.py | 2 +- src/diffusers/guiders/auto_guidance.py | 2 +- src/diffusers/guiders/classifier_free_guidance.py | 2 +- src/diffusers/guiders/classifier_free_zero_star_guidance.py | 2 +- src/diffusers/guiders/guider_utils.py | 4 ++-- src/diffusers/guiders/skip_layer_guidance.py | 2 +- src/diffusers/guiders/smoothed_energy_guidance.py | 2 +- src/diffusers/guiders/tangential_classifier_free_guidance.py | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 83e93c15ff1d..ef2f3f2c8420 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AdaptiveProjectedGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 8bb6083781c2..791cc582add2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AutoGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429392e3f9c6..a459e51cd083 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 220a95e54a8d..a722f2605036 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py 
@@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeZeroStarGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index df544c955f33..e8e873f5c88f 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState logger = get_logger(__name__) # pylint: disable=invalid-name @@ -171,7 +171,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da Returns: `BlockState`: The prepared batch of data. """ - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState if input_fields is None: raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 56dae1903606..7c19f6391f41 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SkipLayerGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index c215cb0afdc9..3986da913f82 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SmoothedEnergyGuidance(BaseGuidance): diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 9fa8f9454134..017693fd9f07 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -20,7 +20,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class TangentialClassifierFreeGuidance(BaseGuidance): From 462429b68747dc1c0a313bb1a8f913de207dde6d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:10 +0200 Subject: [PATCH 15/38] remove modular reelated change from pipelines folder --- src/diffusers/pipelines/modular_pipeline.py | 1916 ----------- .../pipelines/modular_pipeline_utils.py | 598 ---- .../pipeline_stable_diffusion_xl_modular.py | 3032 ----------------- ...table_diffusion_xl_modular_denoise_loop.py | 1363 -------- 4 files changed, 6909 deletions(-) delete mode 100644 src/diffusers/pipelines/modular_pipeline.py delete mode 100644 src/diffusers/pipelines/modular_pipeline_utils.py delete mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py delete mode 100644 
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py deleted file mode 100644 index 97a8677bda63..000000000000 --- a/src/diffusers/pipelines/modular_pipeline.py +++ /dev/null @@ -1,1916 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type - - -import torch -from tqdm.auto import tqdm -import re -import os -import importlib - -from huggingface_hub.utils import validate_hf_hub_args - -from ..configuration_utils import ConfigMixin, FrozenDict -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, - PushToHubMixin, -) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple -from .modular_pipeline_utils import ( - ComponentSpec, - ConfigSpec, - InputParam, - OutputParam, - format_components, - format_configs, - format_input_params, - format_inputs_short, - format_intermediates_short, - format_output_params, - format_params, - make_doc_string, -) -from .components_manager import ComponentsManager - -from copy import deepcopy -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -MODULAR_LOADER_MAPPING = OrderedDict( - [ - ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), - ] -) - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - - def add_input(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an input to the pipeline state with optional metadata. - - Args: - key (str): The key for the input - value (Any): The input value - kwargs_type (str): The kwargs_type to store with the input - """ - self.inputs[key] = value - if kwargs_type is not None: - if kwargs_type not in self.input_kwargs: - self.input_kwargs[kwargs_type] = [key] - else: - self.input_kwargs[kwargs_type].append(key) - - def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an intermediate value to the pipeline state with optional metadata. 
- - Args: - key (str): The key for the intermediate value - value (Any): The intermediate value - kwargs_type (str): The kwargs_type to store with the intermediate value - """ - self.intermediates[key] = value - if kwargs_type is not None: - if kwargs_type not in self.intermediate_kwargs: - self.intermediate_kwargs[kwargs_type] = [key] - else: - self.intermediate_kwargs[kwargs_type].append(key) - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.inputs.get(key, default) for key in keys} - - def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all inputs with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of inputs with matching kwargs_type - """ - input_names = self.input_kwargs.get(kwargs_type, []) - return self.get_inputs(input_names) - - def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all intermediates with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of intermediates with matching kwargs_type - """ - intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) - return self.get_intermediates(intermediate_names) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.intermediates.get(key, default) for key in keys} - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - - # Format input_kwargs and intermediate_kwargs - input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) - intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" input_kwargs={{\n{input_kwargs_str}\n }},\n" - f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __getitem__(self, key: str): - # allows block_state["foo"] - return getattr(self, key, None) - - def __setitem__(self, key: str, value: Any): - # allows block_state["foo"] = "bar" - setattr(self, key, value) - - def as_dict(self): - """ - Convert BlockState to a dictionary. 
- - Returns: - Dict[str, Any]: Dictionary containing all attributes of the BlockState - """ - return {key: value for key, value in self.__dict__.items()} - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - formatted_dict = {} - for k, val in v.items(): - if hasattr(val, "shape") and hasattr(val, "dtype"): - formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" - elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): - shapes = [t.shape for t in val] - formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" - else: - formatted_dict[k] = repr(val) - return formatted_dict - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - - -class ModularPipelineMixin: - """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks - """ - - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): - """ - create a mouldar loader, optionally accept modular_repo to load from hub. - """ - - # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] - diffusers_module = importlib.import_module("diffusers") - loader_class = getattr(diffusers_module, loader_class_name) - - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - # Create the loader with the updated specs - specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) - - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. 
- """ - if state is None: - state = PipelineState() - - if not hasattr(self, "loader"): - logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") - self.loader = None - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for expected_input_param in self.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediates_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.intermediates_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - @torch.compiler.disable - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - -class PipelineBlock(ModularPipelineMixin): - - model_name = None - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - - - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. 
Must be implemented by subclasses.""" - return [] - - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs - - @property - def required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - use format_components with add_empty_lines=False - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - components = " " + components_str.replace("\n", "\n ") - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - configs = " " + configs_str.replace("\n", "\n ") - - # Inputs section - inputs_str = format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f"{components}\n" - f"{configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in 
self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name is None and input_param.kwargs_type is not None: - input_name = "*_" + input_param.kwargs_type - else: - input_name = input_param.name - if input_name in combined_dict: - current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. 
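To make the `PipelineBlock` contract above concrete, the sketch below defines a toy block that reads its state with `get_block_state` and writes results back with `add_block_state`, then drives it through `run()` (which only logs a warning when no loader is set up). The class name and parameter names are illustrative, and the `InputParam`/`OutputParam` constructor signatures are assumed from their usage elsewhere in this file.

```python
import torch
from diffusers.pipelines.modular_pipeline import PipelineBlock, PipelineState
from diffusers.pipelines.modular_pipeline_utils import InputParam, OutputParam

class ScaleLatentsBlock(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Toy block: multiplies `latents` by `scale` and writes back `scaled_latents`."

    @property
    def inputs(self):
        return [InputParam("scale", float, default=0.18215)]

    @property
    def intermediates_inputs(self):
        return [InputParam("latents", required=True)]

    @property
    def intermediates_outputs(self):
        return [OutputParam("scaled_latents", type_hint=torch.Tensor)]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        block_state = self.get_block_state(state)        # BlockState with `scale` and `latents`
        block_state.scaled_latents = block_state.latents * block_state.scale
        self.add_block_state(state, block_state)         # writes `scaled_latents` back as an intermediate
        return components, state

block = ScaleLatentsBlock()
scaled = block.run(latents=torch.randn(1, 4, 64, 64), output="scaled_latents")
```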
- - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {} # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - -class AutoPipelineBlocks(ModularPipelineMixin): - """ - A class that automatically selects a block to run based on the inputs. - - Attributes: - block_classes: List of block classes to be used - block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_names = [] - block_trigger_inputs = [] - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last - # the order of blocksmatters here because the first block with matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): - raise ValueError( - f"In {self.__class__.__name__}, exactly one None must be specified as the last element " - "in block_trigger_inputs." 
- ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - - @property - def required_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - @property - def required_intermediates_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - - @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and 
state.get_input(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - elif input_name is not None and state.get_intermediate(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.warning(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - - return trigger_inputs - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -class SequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. - """ - block_classes = [] - block_names = [] - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": - """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. 
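A short sketch of composing blocks with `from_blocks_dict`, reusing the toy `ScaleLatentsBlock` from the earlier sketch; a second copy stands in for a follow-up step so the execution order is visible in the repr. This is only an illustration of the intended call pattern, not a documented recipe.

```python
from collections import OrderedDict
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks

blocks = SequentialPipelineBlocks.from_blocks_dict(
    OrderedDict([("scale", ScaleLatentsBlock()), ("scale_again", ScaleLatentsBlock())])
)

print(blocks)       # lists the blocks in execution order with their descriptions
print(blocks.doc)   # doc string assembled from the combined inputs/outputs
state = blocks.run(latents=torch.randn(1, 4, 64, 64))  # runs each block in sequence on one PipelineState
```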
- - Args: - blocks_dict: Dictionary mapping block names to block instances - - Returns: - A new SequentialPipelineBlocks instance - """ - instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict - return instance - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - return self.get_inputs() - - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_inputs(self) -> List[str]: - return self.get_intermediates_inputs() - - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): - try: - pipeline, state = block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return pipeline, state - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. 
- Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - return fn_recursive_get_trigger(self.blocks) - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): - result_blocks = OrderedDict() - - # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): - # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} - result_blocks.update(blocks_to_update) - else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): - active_triggers.update(out.name for out in block.outputs) - return result_blocks - - # auto - else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - matching_trigger = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - matching_trigger = None - - if this_block is not None: - # sequential/auto (keep traversing) - if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? 
- if hasattr(this_block, 'outputs'): - active_triggers.update(out.name for out in this_block.outputs) - - return result_blocks - - all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - all_blocks.update(blocks_to_update) - return all_blocks - - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs - - if trigger_inputs is not None: - - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) - return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -#YiYi TODO: __repr__ -class LoopSequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. - """ - - model_name = None - block_classes = [] - block_names = [] - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def loop_expected_configs(self) -> List[ConfigSpec]: - return [] - - @property - def loop_inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. 
Must be implemented by subclasses.""" - return [] - - - @property - def loop_required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def loop_required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - # modified from SequentialPipelineBlocks to include loop_expected_components - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - for component in self.loop_expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - # modified from SequentialPipelineBlocks to include loop_expected_configs - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - for config in self.loop_expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - # modified from SequentialPipelineBlocks to include loop_inputs - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - named_inputs.append(("loop", self.loop_inputs)) - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - # Copied from SequentialPipelineBlocks - @property - def inputs(self): - return self.get_inputs() - - - # modified from SequentialPipelineBlocks to include loop_intermediates_inputs - @property - def intermediates_inputs(self): - intermediates = self.get_intermediates_inputs() - intermediate_names = [input.name for input in intermediates] - for loop_intermediate_input in self.loop_intermediates_inputs: - if loop_intermediate_input.name not in intermediate_names: - intermediates.append(loop_intermediate_input) - return intermediates - - - # Copied from SequentialPipelineBlocks - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - - # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - 
required_by_loop = set(getattr(self, "loop_required_inputs", set())) - required_by_any.update(required_by_loop) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - for input_param in self.loop_intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - - # YiYi TODO: this need to be thought about more - # modified from SequentialPipelineBlocks to include loop_intermediates_outputs - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - for output in self.loop_intermediates_outputs: - if output.name not in set([output.name for output in combined_outputs]): - combined_outputs.append(output) - return combined_outputs - - # YiYi TODO: this need to be thought about more - # copied from SequentialPipelineBlocks - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - def loop_step(self, components, state: PipelineState, **kwargs): - - for block_name, block in self.blocks.items(): - try: - components, state = block(components, state, **kwargs) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return components, state - - def __call__(self, components, state: PipelineState) -> PipelineState: - raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif 
value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - -# YiYi TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() -class ModularLoader(ConfigMixin, PushToHubMixin): - """ - Base class for all Modular pipelines loaders. - - """ - config_name = "modular_model_index.json" - - - def register_components(self, **kwargs): - """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - - Args: - **kwargs: Keyword arguments where keys are component names and values are component objects. - - """ - for name, module in kwargs.items(): - - # current component spec - component_spec = self._component_specs.get(name) - if component_spec is None: - logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") - continue - - is_registered = hasattr(self, name) - - if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - - # actual library and class name of the module - - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) - - else: - library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec - component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} - - # set the component as attribute - # if it is not set yet, just set it and skip the process to check and warn below - if not is_registered: - self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec - setattr(self, name, module) - if module is not None and self._component_manager is not None: - 
self._component_manager.add(name, module, self._collection) - continue - - current_module = getattr(self, name, None) - # skip if the component is already registered with the same object - if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") - continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") - - # warn if unregister - if current_module is not None and module is None: - logger.info( - f"ModularLoader.register_components: setting '{name}' to None " - f"(was {current_module.__class__.__name__})" - ) - # same type, new instance → debug - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: - logger.debug( - f"ModularLoader.register_components: replacing existing '{name}' " - f"(same type {type(current_module).__name__}, new instance)" - ) - - # save modular_model_index.json config - self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec - # finally set models - setattr(self, name, module) - if module is not None and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) - - - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - """ - Initialize the loader with a list of component specs and config specs. - """ - self._component_manager = component_manager - self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } - - # update component_specs and config_specs from modular_repo - if modular_repo is not None: - config_dict = self.load_config(modular_repo, **kwargs) - - for name, value in config_dict.items(): - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - self._component_specs[name] = component_spec - - elif name in self._config_specs: - self._config_specs[name].default = value - - register_components_dict = {} - for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None - self.register_components(**register_components_dict) - - default_configs = {} - for name, config_spec in self._config_specs.items(): - default_configs[name] = config_spec.default - self.register_to_config(**default_configs) - - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. 
- """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module): - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - - @property - def components(self) -> Dict[str, Any]: - # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } - - def update(self, **kwargs): - """ - Update components and configs after instance creation. - - Args: - - """ - """ - Update components and configuration values after the loader has been instantiated. - - This method allows you to: - 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) - 2. 
Update configuration values (e.g., changing requires_safety_checker flag) - - Args: - **kwargs: Component objects or configuration values to update: - - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) - - Raises: - ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) - - Examples: - ```python - # Update multiple components at once - loader.update( - unet=new_unet_model, - text_encoder=new_text_encoder - ) - - # Update configuration values - loader.update( - requires_safety_checker=False, - guidance_rescale=0.7 - ) - - # Update both components and configs together - loader.update( - unet=new_unet_model, - requires_safety_checker=False - ) - ``` - """ - - # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} - passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} - - for name, component in passed_components.items(): - if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - - if len(kwargs) > 0: - logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - - self.register_components(**passed_components) - - - config_to_register = {} - for name, new_value in passed_config_values.items(): - - # e.g. requires_aesthetics_score = False - self._config_specs[name].default = new_value - config_to_register[name] = new_value - self.register_to_config(**config_to_register) - - - # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, component_names: Optional[List[str]] = None, **kwargs): - """ - Load selectedcomponents from specs. - - Args: - component_names: List of component names to load - **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. 
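The dict form of the `load()` kwargs described above lets different components receive different loading options. A hedged sketch follows; `loader` is assumed to be a `ModularLoader` already built from the pipeline's specs (e.g. via `setup_loader()`), and the component names, repo, and dtypes are placeholders.

```python
import torch

# only the named components are loaded; everything else keeps its spec untouched
loader.load(
    ["unet", "text_encoder", "vae"],
    torch_dtype={"unet": torch.float16, "default": torch.float32},  # per-component override + default
    variant="fp16",                                                  # single value applied to all loaded components
)
print(loader.components)  # components that have been registered on the loader so far
```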
- """ - if component_names is None: - component_names = list(self._component_specs.keys()) - elif not isinstance(component_names, list): - component_names = [component_names] - - components_to_load = set([name for name in component_names if name in self._component_specs]) - unknown_component_names = set([name for name in component_names if name not in self._component_specs]) - if len(unknown_component_names) > 0: - logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - - components_to_register = {} - for name in components_to_load: - spec = self._component_specs[name] - component_load_kwargs = {} - for key, value in kwargs.items(): - if not isinstance(value, dict): - # if the value is a single value, apply it to all components - component_load_kwargs[key] = value - else: - if name in value: - # if it is a dict, check if the component name is in the dict - component_load_kwargs[key] = value[name] - elif "default" in value: - # check if the default is specified - component_load_kwargs[key] = value["default"] - try: - components_to_register[name] = spec.create(**component_load_kwargs) - except Exception as e: - logger.warning(f"Failed to create component '{name}': {e}") - - # Register all components at once - self.register_components(**components_to_register) - - # YiYi TODO: should support to method - def to(self, *args, **kwargs): - pass - - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - - component_names = list(self._component_specs.keys()) - config_names = list(self._config_specs.keys()) - self.register_to_config(_components_names=component_names, _configs_names=config_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - config.pop("_components_names", None) - config.pop("_configs_names", None) - self._internal_dict = FrozenDict(config) - - - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): - - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - expected_component = set(config_dict.pop("_components_names")) - expected_config = set(config_dict.pop("_configs_names")) - - component_specs = [] - config_specs = [] - for name, value in config_dict.items(): - if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) - - elif name in expected_config: - config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) - - - @staticmethod - def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: - """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in - `modular_model_index.json`. 
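As a rough sketch of the spec-only serialization round-trip implemented above: `save_pretrained` writes `modular_model_index.json`, and `from_pretrained` rebuilds a loader whose components are unset until `load()` materializes them. The paths and the component name are placeholders, and `loader` is assumed to exist from the previous sketch.

```python
from diffusers.pipelines.modular_pipeline import ModularLoader

loader.save_pretrained("./my_modular_config")            # writes modular_model_index.json (specs only)

restored = ModularLoader.from_pretrained("./my_modular_config")
restored.load("unet")                                     # materialize a single component from its spec
```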
- - This dict contains: - - "type_hint": Tuple[str, str] - The fully‐qualified module path and class name of the component. - - All loading fields defined by `component_spec.loading_fields()`, typically: - - "repo": Optional[str] - The model repository (e.g., "stabilityai/stable-diffusion-xl"). - - "subfolder": Optional[str] - A subfolder within the repo where this component lives. - - "variant": Optional[str] - An optional variant identifier for the model. - - "revision": Optional[str] - A specific git revision (commit hash, tag, or branch). - - ... any other loading fields defined on the spec. - - Args: - component_spec (ComponentSpec): - The spec object describing one pipeline component. - - Returns: - Dict[str, Any]: A mapping suitable for JSON serialization. - - Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec - >>> from diffusers.models.unet import UNet2DConditionModel - >>> spec = ComponentSpec( - ... name="unet", - ... type_hint=UNet2DConditionModel, - ... config=None, - ... repo="path/to/repo", - ... subfolder="subfolder", - ... variant=None, - ... revision=None, - ... default_creation_method="from_pretrained", - ... ) - >>> ModularLoader._component_spec_to_dict(spec) - { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": None, - "revision": None, - } - """ - if component_spec.type_hint is not None: - lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) - else: - lib_name = None - cls_name = None - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} - return { - "type_hint": (lib_name, cls_name), - **load_spec_dict, - } - - @staticmethod - def _dict_to_component_spec( - name: str, - spec_dict: Dict[str, Any], - ) -> ComponentSpec: - """ - Reconstruct a ComponentSpec from a dict. - """ - # make a shallow copy so we can pop() safely - spec_dict = spec_dict.copy() - # pull out and resolve the stored type_hint - lib_name, cls_name = spec_dict.pop("type_hint") - if lib_name is not None and cls_name is not None: - type_hint = simple_get_class_obj(lib_name, cls_name) - else: - type_hint = None - - # re‐assemble the ComponentSpec - return ComponentSpec( - name=name, - type_hint=type_hint, - **spec_dict, - ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py deleted file mode 100644 index 392d6dcd9521..000000000000 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
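As context for the `load()` method shown earlier in this patch: a plain kwarg value is forwarded to every component's `from_pretrained()` call, while a dict-valued kwarg is resolved per component name with an optional "default" fallback. A minimal, self-contained sketch of that resolution rule follows; the helper name `resolve_load_kwargs` and the string stand-ins for torch dtypes are illustrative only, not diffusers API.

from typing import Any, Dict


def resolve_load_kwargs(name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative restatement of the per-component kwarg resolution in load()."""
    resolved = {}
    for key, value in kwargs.items():
        if not isinstance(value, dict):
            # a single value applies to every component being loaded
            resolved[key] = value
        elif name in value:
            # a dict is looked up by component name first
            resolved[key] = value[name]
        elif "default" in value:
            # and falls back to the optional "default" entry
            resolved[key] = value["default"]
    return resolved


# e.g. torch_dtype={"unet": torch.float16, "default": torch.float32};
# plain strings stand in for torch dtypes to keep the sketch dependency-free.
print(resolve_load_kwargs("unet", {"torch_dtype": {"unet": "float16", "default": "float32"}}))  # {'torch_dtype': 'float16'}
print(resolve_load_kwargs("vae", {"torch_dtype": {"unet": "float16", "default": "float32"}}))   # {'torch_dtype': 'float32'}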
- -import re -import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal - -from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin - -if is_torch_available(): - import torch - - -# YiYi TODO: -# 1. validate the dataclass fields -# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() -@dataclass -class ComponentSpec: - """Specification for a pipeline component. - - A component can be created in two ways: - 1. From scratch using __init__ with a config dict - 2. using `from_pretrained` - - Attributes: - name: Name of the component - type_hint: Type of the component (e.g. UNet2DConditionModel) - description: Optional description of the component - config: Optional config dict for __init__ creation - repo: Optional repo path for from_pretrained creation - subfolder: Optional subfolder in repo - variant: Optional variant in repo - revision: Optional revision in repo - default_creation_method: Preferred creation method - "from_config" or "from_pretrained" - """ - name: Optional[str] = None - type_hint: Optional[Type] = None - description: Optional[str] = None - config: Optional[FrozenDict[str, Any]] = None - # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name - repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) - subfolder: Optional[str] = field(default=None, metadata={"loading": True}) - variant: Optional[str] = field(default=None, metadata={"loading": True}) - revision: Optional[str] = field(default=None, metadata={"loading": True}) - default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - - def __hash__(self): - """Make ComponentSpec hashable, using load_id as the hash value.""" - return hash((self.name, self.load_id, self.default_creation_method)) - - def __eq__(self, other): - """Compare ComponentSpec objects based on name and load_id.""" - if not isinstance(other, ComponentSpec): - return False - return (self.name == other.name and - self.load_id == other.load_id and - self.default_creation_method == other.default_creation_method) - - @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" - - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") - - type_hint = component.__class__ - - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): - config = component.config - else: - config = None - - load_spec = cls.decode_load_id(component._diffusers_load_id) - - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) - - @classmethod - def loading_fields(cls) -> List[str]: - """ - Return the names of all loading‐related fields - (i.e. those whose field.metadata["loading"] is True). 
- """ - return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - - @property - def load_id(self) -> str: - """ - Unique identifier for this spec's pretrained load, - composed of repo|subfolder|variant|revision (no empty segments). - """ - parts = [getattr(self, k) for k in self.loading_fields()] - parts = ["null" if p is None else p for p in parts] - return "|".join(p for p in parts if p) - - @classmethod - def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: - """ - Decode a load_id string back into a dictionary of loading fields and values. - - Args: - load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" - where None values are represented as "null" - - Returns: - Dict mapping loading field names to their values. e.g. - { - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": "variant", - "revision": "revision" - } - If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). - """ - - # Get all loading fields in order - loading_fields = cls.loading_fields() - result = {f: None for f in loading_fields} - - if load_id == "null": - return result - - # Split the load_id - parts = load_id.split("|") - - # Map parts to loading fields by position - for i, part in enumerate(parts): - if i < len(loading_fields): - # Convert "null" string back to None - result[loading_fields[i]] = None if part == "null" else part - - return result - - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: - """Create component using from_config with config.""" - - if self.type_hint is None or not isinstance(self.type_hint, type): - raise ValueError( - f"`type_hint` is required when using from_config creation method." - ) - - config = config or self.config or {} - - if issubclass(self.type_hint, ConfigMixin): - component = self.type_hint.from_config(config, **kwargs) - else: - signature_params = inspect.signature(self.type_hint.__init__).parameters - init_kwargs = {} - for k, v in config.items(): - if k in signature_params: - init_kwargs[k] = v - for k, v in kwargs.items(): - if k in signature_params: - init_kwargs[k] = v - component = self.type_hint(**init_kwargs) - - component._diffusers_load_id = "null" - if hasattr(component, "config"): - self.config = component.config - - return component - - # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" - - passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} - load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} - # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path - repo = load_kwargs.pop("repo", None) - if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - - if self.type_hint is None: - try: - from diffusers import AutoModel - component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") - self.type_hint = component.__class__ - else: - try: - component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") - - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) - component._diffusers_load_id = self.load_id - - return component - - - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None -@dataclass -class InputParam: - """Specification for an input parameter.""" - name: str = None - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - - -@dataclass -class OutputParam: - """Specification for an output parameter.""" - name: str - type_hint: Any = None - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): - """ - Formats intermediate inputs and outputs of a block into a string representation. 
- - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - if inp.name is None and inp.kwargs_type is not None: - inp_name = "*_" + inp.kwargs_type - else: - inp_name = inp.name - input_parts.append(inp_name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params, header="Args", indent_level=4, max_line_length=115): - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text, indent, max_length): - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if hasattr(param, "required"): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add 
description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - - -def format_input_params(input_params, indent_level=4, max_line_length=115): - """Format a list of InputParam objects into a readable string representation. - - Args: - input_params: List of InputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all input parameters - """ - return format_params(input_params, "Inputs", indent_level, max_line_length) - - -def format_output_params(output_params, indent_level=4, max_line_length=115): - """Format a list of OutputParam objects into a readable string representation. - - Args: - output_params: List of OutputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all output parameters - """ - return format_params(output_params, "Outputs", indent_level, max_line_length) - - -def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ComponentSpec objects into a readable string representation. - - Args: - components: List of ComponentSpec objects to format - indent_level: Number of spaces to indent each component line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between components (default: True) - - Returns: - A formatted string representing all components - """ - if not components: - return "" - - base_indent = " " * indent_level - component_indent = " " * (indent_level + 4) - formatted_components = [] - - # Add the header - formatted_components.append(f"{base_indent}Components:") - if add_empty_lines: - formatted_components.append("") - - # Add each component with optional empty lines between them - for i, component in enumerate(components): - # Get type name, handling special cases - type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - - component_desc = f"{component_indent}{component.name} (`{type_name}`)" - if component.description: - component_desc += f": {component.description}" - - # Get the loading fields dynamically - loading_field_values = [] - for field_name in component.loading_fields(): - field_value = getattr(component, field_name) - if field_value is not None: - loading_field_values.append(f"{field_name}={field_value}") - - # Add loading field information if available - if loading_field_values: - component_desc += f" [{', '.join(loading_field_values)}]" - - formatted_components.append(component_desc) - - # Add an empty line after each component except the last one - if add_empty_lines and i < len(components) - 1: - formatted_components.append("") - - return "\n".join(formatted_components) - - -def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ConfigSpec objects into a readable string representation. 
- - Args: - configs: List of ConfigSpec objects to format - indent_level: Number of spaces to indent each config line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between configs (default: True) - - Returns: - A formatted string representing all configs - """ - if not configs: - return "" - - base_indent = " " * indent_level - config_indent = " " * (indent_level + 4) - formatted_configs = [] - - # Add the header - formatted_configs.append(f"{base_indent}Configs:") - if add_empty_lines: - formatted_configs.append("") - - # Add each config with optional empty lines between them - for i, config in enumerate(configs): - config_desc = f"{config_indent}{config.name} (default: {config.default})" - if config.description: - config_desc += f": {config.description}" - formatted_configs.append(config_desc) - - # Add an empty line after each config except the last one - if add_empty_lines and i < len(configs) - 1: - formatted_configs.append("") - - return "\n".join(formatted_configs) - - -def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Args: - inputs: List of input parameters - intermediates_inputs: List of intermediate input parameters - outputs: List of output parameters - description (str, *optional*): Description of the block - class_name (str, *optional*): Name of the class to include in the documentation - expected_components (List[ComponentSpec], *optional*): List of expected components - expected_configs (List[ConfigSpec], *optional*): List of expected configurations - - Returns: - str: A formatted string containing information about components, configs, call parameters, - intermediate inputs/outputs, and final outputs. - """ - output = "" - - # Add class name if provided - if class_name: - output += f"class {class_name}\n\n" - - # Add description - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - # Add components section if provided - if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) - output += components_str + "\n\n" - - # Add configs section if provided - if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) - output += configs_str + "\n\n" - - # Add inputs section - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - # Add outputs section - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index acb395345086..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3032 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularLoader, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, - ComponentSpec, - ConfigSpec, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance -from ...configuration_utils import FrozenDict - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? 
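The notes above describe the model-specific loader as a thin container whose config-like values are derived from whichever components are currently attached, falling back to hard-coded SDXL defaults otherwise; the class that follows implements exactly that for `default_sample_size`, `vae_scale_factor`, `num_channels_unet`, and `num_channels_latents`. A minimal sketch of the pattern; the `LoaderDefaults` class is illustrative and not part of diffusers.

class LoaderDefaults:
    """Illustrative stand-in for a ModularLoader subclass, not diffusers API."""

    vae = None  # would be an AutoencoderKL once load()/update() registers one

    @property
    def vae_scale_factor(self) -> int:
        # derived from the attached VAE when available, otherwise the SDXL default
        if self.vae is not None:
            return 2 ** (len(self.vae.config.block_out_channels) - 1)
        return 8


loader = LoaderDefaults()
print(loader.vae_scale_factor)  # 8 until a VAE is registered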
-class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -class StableDiffusionXLIPAdapterStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - @staticmethod - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = 
image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - components, - ip_adapter_image=block_state.ip_adapter_image, - ip_adapter_image_embeds=None, - device=block_state.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, - ) - if block_state.prepare_unconditional_embeds: - block_state.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(block_state.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - block_state.negative_ip_adapter_embeds.append(negative_image_embeds) - block_state.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return( - "Text Encoder step that 
generate text_embeddings to guide the image generation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), - ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("force_zeros_for_empty_prompt", True)] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("prompt"), - InputParam("prompt_2"), - InputParam("negative_prompt"), - InputParam("negative_prompt_2"), - InputParam("cross_attention_kwargs"), - InputParam("clip_skip"), - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - ] - - @staticmethod - def check_inputs(block_state): - - if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") - - @staticmethod - def encode_prompt( - components, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prepare_unconditional_embeds (`bool`): - whether to use prepare unconditional embeddings or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or components._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] - text_encoders = ( - [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids 
= text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if components.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if prepare_unconditional_embeds: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if components.text_encoder is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - block_state = self.get_block_state(state) - self.check_inputs(block_state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - # Encode input prompt - block_state.text_encoder_lora_scale = ( - block_state.cross_attention_kwargs.get("scale", None) if 
block_state.cross_attention_kwargs is not None else None - ) - ( - block_state.prompt_embeds, - block_state.negative_prompt_embeds, - block_state.pooled_prompt_embeds, - block_state.negative_pooled_prompt_embeds, - ) = self.encode_prompt( - components, - block_state.prompt, - block_state.prompt_2, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - block_state.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=block_state.text_encoder_lora_scale, - clip_skip=block_state.clip_skip, - ) - # Add outputs - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("generator"), - InputParam("height"), - InputParam("width"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = 
latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - - block_state.batch_size = block_state.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." - ) - - - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ComponentSpec( - "mask_processor", - VaeImageProcessor, - config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), - default_creation_method="from_config"), - ] - - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - # Modified from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. 
Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - if block_state.padding_mask_crop is not None: - block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) - block_state.resize_mode = "fill" - else: - block_state.crops_coords = None - block_state.resize_mode = "default" - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) - block_state.image = block_state.image.to(dtype=torch.float32) - - block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) - block_state.masked_image = block_state.image * (block_state.mask < 0.5) - - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image, - block_state.batch_size, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - - return components, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. 
Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, components, block_state): - - if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: - if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {block_state.negative_prompt_embeds.shape}." - ) - - if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
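
The shape handling this step describes reduces to a repeat-then-view over the embeddings. A minimal, self-contained sketch of that duplication pattern (dummy shapes, not taken from the patch):

import torch

# Duplicate prompt embeddings for num_images_per_prompt, mirroring the pattern used in this step.
batch_size, seq_len, dim = 2, 77, 2048
num_images_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)                    # (2, 3*77, 2048)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)  # (6, 77, 2048)

assert prompt_embeds.shape == (batch_size * num_images_per_prompt, seq_len, dim)
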
- ) - - if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {block_state.negative_ip_adapter_embeds[i].shape}." - ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype - - _, seq_len, _ = block_state.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - if block_state.negative_prompt_embeds is not None: - _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.negative_pooled_prompt_embeds is not None: - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - if block_state.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): - block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the 
initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - InputParam("strength", default=0.3), - InputParam("denoising_start"), - # YiYi TODO: do we need num_images_per_prompt here? - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components - def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start * components.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (denoising_start * components.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if components.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. 
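
As an aside, a minimal sketch of how `strength` picks the starting point of the schedule, assuming a first-order scheduler and a stand-in linearly spaced timestep schedule (values are illustrative, not from the patch):

import torch

num_inference_steps = 50
strength = 0.3
timesteps = torch.linspace(999, 0, num_inference_steps)  # stand-in for scheduler.timesteps

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 15
t_start = max(num_inference_steps - init_timestep, 0)                          # 35

timesteps = timesteps[t_start:]  # only the final 15 steps of the schedule are run
latent_timestep = timesteps[:1]  # noise level the input image latents are diffused to
print(len(timesteps), latent_timestep.item())
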
By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(components.scheduler.timesteps) - num_inference_steps - timesteps = components.scheduler.timesteps[t_start:] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( - components, - block_state.num_inference_steps, - block_state.strength, - block_state.device, - denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, - ) - block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - 
block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - InputParam( - "strength", - default=0.9999, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
- ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - @staticmethod - def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument - def prepare_latents_inpaint( - self, - components, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." 
- "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(components, image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * components.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - block_state.is_strength_max = block_state.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components,"unet") and components.unet is not None: - if components.unet.config.in_channels == 4: - block_state.masked_image_latents = None - - block_state.add_noise = True if block_state.denoising_start is None else False - - block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor - block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor - - block_state.latents, block_state.noise = self.prepare_latents_inpaint( - components, - block_state.batch_size * block_state.num_images_per_prompt, - components.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - image=block_state.image_latents, - timestep=block_state.latent_timestep, - is_strength_max=block_state.is_strength_max, - add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image_latents, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - @staticmethod - def prepare_latents_img2img( - components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
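
A small sketch of the latent normalization applied after VAE encoding (only `scaling_factor` matches the SDXL VAE default; the statistics are dummies and not from the patch):

import torch

scaling_factor = 0.13025                 # SDXL VAE default scaling factor
init_latents = torch.randn(1, 4, 128, 128)

latents_mean = torch.zeros(1, 4, 1, 1)   # stand-ins for vae.config.latents_mean / latents_std
latents_std = torch.ones(1, 4, 1, 1)

if latents_mean is not None and latents_std is not None:
    init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
else:
    init_latents = scaling_factor * init_latents
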
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - block_state.add_noise = True if block_state.denoising_start is None else False - if block_state.latents is None: - block_state.latents = self.prepare_latents_img2img( - components, - block_state.image_latents, - block_state.latent_timestep, - block_state.batch_size, - block_state.num_images_per_prompt, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.add_noise, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - - @staticmethod - def check_inputs(components, block_state): - if ( - block_state.height is not None - and block_state.height % components.vae_scale_factor != 0 - or block_state.width is not None - and block_state.width % components.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - @staticmethod - def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
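
For comparison with the image-to-image path above, a minimal sketch of the two initialization modes (stand-in values, not from the patch): text-to-image starts from pure noise scaled by the scheduler's initial sigma, while image-to-image noises the encoded image latents to `latent_timestep` via `scheduler.add_noise`:

import torch

init_noise_sigma = 14.6   # stand-in; the real value comes from scheduler.init_noise_sigma
shape = (1, 4, 128, 128)  # batch, latent channels, height // vae_scale_factor, width // vae_scale_factor

noise = torch.randn(shape)
latents = noise * init_noise_sigma  # text-to-image: pure noise, scaled for the first scheduler step
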
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * components.scheduler.init_noise_sigma - return latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.dtype is None: - block_state.dtype = components.vae.dtype - - block_state.device = components._execution_device - - self.check_inputs(components, block_state) - - block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor - block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor - block_state.num_channels_latents = components.num_channels_latents - block_state.latents = self.prepare_latents( - components, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", False),] - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("aesthetic_score", default=6.0), - InputParam("negative_aesthetic_score", default=2.0), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids_img2img( - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
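
A short sketch of how the SDXL micro-conditioning `add_time_ids` are assembled (illustrative values only; with `requires_aesthetics_score` enabled the target size is replaced by the aesthetic score):

import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
aesthetic_score = 6.0
requires_aesthetics_score = False

if requires_aesthetics_score:
    add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
else:
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

add_time_ids = torch.tensor([add_time_ids], dtype=torch.float32)  # [[1024., 1024., 0., 0., 1024., 1024.]]
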
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.vae_scale_factor = components.vae_scale_factor - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * block_state.vae_scale_factor - block_state.width = block_state.width * block_state.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - if block_state.negative_original_size is None: - block_state.negative_original_size = block_state.original_size - if block_state.negative_target_size is None: - block_state.negative_target_size = block_state.target_size - - block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.aesthetic_score, - block_state.negative_aesthetic_score, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - dtype=block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and 
components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids( - components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
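
A self-contained sketch of the sinusoidal guidance-scale embedding used to build `timestep_cond` for LCM-style UNets (stand-in `embedding_dim`; mirrors `get_guidance_scale_embedding`):

import torch

w = torch.tensor([7.5 - 1.0])  # embedded guidance scale for one sample
embedding_dim = 256            # stand-in for unet.config.time_cond_proj_dim
w = w * 1000.0

half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = w[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
assert emb.shape == (1, embedding_dim)
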
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - block_state.add_time_ids = self._get_add_time_ids( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - if block_state.negative_original_size is not None and block_state.negative_target_size is not None: - block_state.negative_add_time_ids = self._get_add_time_ids( - components, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - else: - block_state.negative_add_time_ids = block_state.add_time_ids - - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. 
Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - -class StableDiffusionXLControlNetInputStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepare inputs for controlnet" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), - OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), - OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. 
add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - # (1) prepare controlnet inputs - block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - block_state.control_guidance_start, block_state.control_guidance_end = ( - mult * [block_state.control_guidance_start], - mult * [block_state.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( - components, - image=block_state.control_image, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in block_state.control_image: - control_image = 
self.prepare_control_image( - components, - image=control_image_, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - - control_images.append(control_image) - - block_state.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - keeps = [ - 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) - for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) - ] - block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - - - self.add_block_state(state, block_state) - - return components, state - -class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepares inputs for the ControlNetUnion model" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_mode", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of model tensor inputs. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), - OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), - OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - controlnet = unwrap_module(components.controlnet) - - device = components._execution_device - dtype = block_state.dtype or components.controlnet.dtype - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - - # guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # control_image - if not isinstance(block_state.control_image, list): - block_state.control_image = [block_state.control_image] - # control_mode - if not isinstance(block_state.control_mode, list): - block_state.control_mode = 
[block_state.control_mode] - - if len(block_state.control_image) != len(block_state.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - # control_type - block_state.num_control_type = controlnet.config.num_control_type - block_state.control_type = [0 for _ in range(block_state.num_control_type)] - for control_idx in block_state.control_mode: - block_state.control_type[control_idx] = 1 - block_state.control_type = torch.Tensor(block_state.control_type) - - block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) - repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] - block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - - # prepare control_image - for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( - components, - image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=device, - dtype=dtype, - crops_coords=block_state.crops_coords, - ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - block_state.controlnet_keep.append( - 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) - ) - block_state.control_type_idx = block_state.control_mode - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): - - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] - block_names = ["controlnet_union", "controlnet"] - block_trigger_inputs = ["control_mode", "control_image"] - - -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "Step that decodes the denoised latents into images" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - @staticmethod - def upcast_vae(components): - dtype = components.vae.dtype - components.vae.to(dtype=torch.float32) - 
use_torch_2_0_or_xformers = isinstance( - components.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - components.vae.post_quant_conv.to(dtype) - components.vae.decoder.conv_in.to(dtype) - components.vae.decoder.mid_block.to(dtype) - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - - if block_state.needs_upcasting: - self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - block_state.has_latents_mean = ( - hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None - ) - block_state.has_latents_std = ( - hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None - ) - if block_state.has_latents_mean and block_state.has_latents_std: - block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean - else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor - - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if block_state.needs_upcasting: - components.vae.to(dtype=torch.float16) - else: - block_state.images = block_state.latents - - # apply watermark if available - if hasattr(components, "watermark") and components.watermark is not None: - block_state.images = components.watermark.apply_watermark(block_state.images) - - block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("images", required=True, 
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: - block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.return_dict: - block_state.images = (block_state.images,) - else: - block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) - self.add_block_state(state, block_state) - return components, state - - -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." 
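The auto blocks in this file pick one of their sub-blocks purely from which call-time inputs are present, checked in the order the trigger inputs are listed, with `None` acting as the fallback. A minimal sketch of that selection rule, using hypothetical names (illustrative only, not the actual `AutoPipelineBlocks` implementation):

def select_block(block_names, block_trigger_inputs, provided_inputs):
    # Return the first sub-block whose trigger input was provided; a trigger of
    # None always matches and therefore serves as the fallback branch.
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or trigger in provided_inputs:
            return name
    return None

# For the VAE-encode auto block above:
# select_block(["inpaint", "img2img"], ["mask_image", "image"], {"mask_image", "image"})  -> "inpaint"
# select_block(["inpaint", "img2img"], ["mask_image", "image"], {"image"})                -> "img2img"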
-
-
-# Before denoise
-class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step.\n" + \
-               "This is a sequential pipeline of blocks:\n" + \
-               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-               " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \
-               " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \
-               " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
-               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
-
-
-class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n" + \
-               "This is a sequential pipeline of blocks:\n" + \
-               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-               " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
-               " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \
-               " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
-               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
-
-
-class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step for the inpainting task.\n" + \
-               "This is a sequential pipeline of blocks:\n" + \
-               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-               " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
-               " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
-               " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
-               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
-
-
-class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
-    block_names = ["inpaint", "img2img", "text2img"]
-    block_trigger_inputs = ["mask", "image_latents", None]
-
-    @property
-    def description(self):
-        return "Before 
denoise step that prepares the inputs for the denoise step.\n" + \
-               "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet and controlnet_union.\n" + \
-               " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \
-               " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
-               " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided.\n" + \
-               " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \
-               " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
-
-# # Denoise
-from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLAutoDenoiseStep
-# class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
-#     block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
-#     block_names = ["controlnet_union", "controlnet", "unet"]
-#     block_trigger_inputs = ["control_mode", "control_image", None]
-
-#     @property
-#     def description(self):
-#         return "Denoise step that denoises the latents.\n" + \
-#                "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
-#                " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
-#                " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \
-#                " - `StableDiffusionXLDenoiseStep` (unet only) is used when neither `control_mode` nor `control_image` is provided."
-
-# After denoise
-class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
-    block_names = ["decode", "output"]
-
-    @property
-    def description(self):
-        return """Decode step that decodes the denoised latents into image outputs.
-This is a sequential pipeline of blocks:
- - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images
- - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple."""
-
-
-class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep]
-    block_names = ["decode", "mask_overlay", "output"]
-
-    @property
-    def description(self):
-        return "Inpaint decode step that decodes the denoised latents into image outputs.\n" + \
-               "This is a sequential pipeline of blocks:\n" + \
-               " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \
-               " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \
-               " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple."
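For reference, the latent unscaling that `StableDiffusionXLDecodeLatentsStep` performs before calling `vae.decode` reduces to the following (a sketch, assuming `scaling_factor`, `latents_mean` and `latents_std` come from the VAE config; most SDXL VAEs define only `scaling_factor`):

import torch

def unscale_latents(latents: torch.Tensor, scaling_factor: float, latents_mean=None, latents_std=None) -> torch.Tensor:
    # Undo the scaling applied when the latents were produced; the per-channel
    # mean/std denormalization is only used when both statistics are available.
    if latents_mean is not None and latents_std is not None:
        return latents * latents_std / scaling_factor + latents_mean
    return latents / scaling_factor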
-
-
-class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
-    block_names = ["inpaint", "non-inpaint"]
-    block_trigger_inputs = ["padding_mask_crop", None]
-
-    @property
-    def description(self):
-        return "Decode step that decodes the denoised latents into image outputs.\n" + \
-               "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
-               " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
-               " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
-
-
-class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
-    block_classes = [StableDiffusionXLIPAdapterStep]
-    block_names = ["ip_adapter"]
-    block_trigger_inputs = ["ip_adapter_image"]
-
-    @property
-    def description(self):
-        return "Run the IP Adapter step if `ip_adapter_image` is provided."
-
-
-class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
-    block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"]
-
-    @property
-    def description(self):
-        return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
-               "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
-               "- for inpainting, you need to provide `mask_image` and `image`; optionally you can provide `padding_mask_crop`\n" + \
-               "- to run the controlnet workflow, you need to provide `control_image`\n" + \
-               "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
-               "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
-               "- for text-to-image generation, all you need to provide is `prompt`"
-
-# TODO(yiyi, aryan): We need another step before the text encoder to set the `num_inference_steps` attribute for the guider,
-# so that things like when to apply guidance and how many conditions need to be prepared can be determined. Currently, the
-# Guiders always assume that guidance will be applied, so negative embeddings are prepared regardless of the guider configuration.
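The `guider` component used by the denoise blocks defaults to classifier-free guidance with `guidance_scale=7.5`; the combination it applies to the conditional and unconditional noise predictions is the standard CFG update. A sketch of that formula only (the actual logic lives in the `ClassifierFreeGuidance` guider):

import torch

def cfg_combine(noise_pred_uncond: torch.Tensor, noise_pred_text: torch.Tensor, guidance_scale: float = 7.5) -> torch.Tensor:
    # Standard classifier-free guidance: move the prediction away from the
    # unconditional branch and toward the text-conditioned branch.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)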
- - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - - -# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": 
InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "eta": InputParam("eta", type_hint=float, 
default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") -} - - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, 
description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") -} - - -SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), - "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), - "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), - "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} - - -SDXL_OUTPUTS_SCHEMA = { - "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} diff --git 
a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py deleted file mode 100644 index 63d0784a5762..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py +++ /dev/null @@ -1,1363 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from tqdm.auto import tqdm - -from ...configuration_utils import FrozenDict -from ...models import ControlNetModel, UNet2DConditionModel -from ...schedulers import EulerDiscreteScheduler -from ...utils import logging -from ...utils.torch_utils import unwrap_module -from ..modular_pipeline import ( - PipelineBlock, - PipelineState, - AutoPipelineBlocks, - LoopSequentialPipelineBlocks, - InputParam, - OutputParam, - BlockState, - ComponentSpec, -) -from ...guiders import ClassifierFreeGuidance -from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader -from dataclasses import asdict - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi experimenting composible denoise loop -# loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - - return components, block_state - -# loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.num_channels_unet - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." 
- ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - - return components, block_state - -# loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step within the denoising loop that denoise the latents with guidance" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - cond_kwargs = guider_state_batch.as_dict() - cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} - prompt_embeds = cond_kwargs.pop("prompt_embeds") - - # Predict the noise residual - # store the noise_pred in guider_state_batch so that we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." 
- ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - InputParam( - kwargs_type="controlnet_kwargs", - description=( - "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" - "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ) - ] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - # cond_scale for the timestep (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - controlnet_cond_scale = block_state.conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] - - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None - - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - - # Prepare additional conditionings - added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: - added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds - - # Prepare controlnet additional conditionings - controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - # run controlnet for the guidance batch - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - down_block_res_samples = block_state.down_block_res_samples_zeros - mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - down_block_res_samples, mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.cond_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, - return_dict=False, - **extra_controlnet_kwargs, - ) - - # assign it to block_state so it will be available for the uncond guidance batch - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) - - # Predict the noise - # store the noise_pred in guider_state_batch so we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. 
Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - #YiYi TODO: move this out of here - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - return components, block_state - -# loop step (3): scheduler step to update latents (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. 
Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - def check_inputs(self, components, block_state): - if components.num_channels_unet == 4: - if block_state.image_latents is None: - raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") - if block_state.mask is None: - raise ValueError(f"mask is required for this step {self.__class__.__name__}") - if block_state.noise is None: - raise ValueError(f"noise is required for this step {self.__class__.__name__}") - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - # adjust latent for inpainting - if components.num_channels_unet == 4: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - - - return components, block_state - - -# the loop wrapper that iterates over the timesteps -class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
- ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components, block_state = self.loop_step(components, block_state, i=i, t=t) - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, block_state) - - return components, state - - -# composing the denoising loops -class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond -class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# mask -class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond + mask -class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - - -# all task without controlnet -class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - -# all task with controlnet -class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] - block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] - block_trigger_inputs = ["mask", None] - -# all task with or without controlnet -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] - - - - - - - -# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = 
"stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
-# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
-# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if 
torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - - -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
-# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
-# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state \ No newline at end of file From 0acb5e1460b2fd2769bcaa38a523c6a9a9f063ea Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:31 +0200 Subject: [PATCH 16/38] made a modular_pipelines folder! 
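
This patch moves the modular pipeline code into its own
src/diffusers/modular_pipelines/ package, split into the loader, the pipeline
block base classes, the ComponentsManager and the SDXL-specific blocks/presets.
As a rough orientation for the new layout, here is a minimal sketch of the
ComponentsManager API added below (the SDXL checkpoint, the variable names and
the "sdxl" collection name are illustrative choices, not part of this patch):

    from diffusers import UNet2DConditionModel
    from diffusers.modular_pipelines import ComponentsManager

    # load a model as usual; the checkpoint here is only an example
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )

    manager = ComponentsManager()
    # components are stored under a generated id ("<name>_<uuid>") and can be
    # grouped into named collections
    unet_id = manager.add("unet", unet, collection="sdxl")
    # simple pattern matching on base names, e.g. every component whose base
    # name starts with "unet"
    unets = manager.get("unet*")
    # offload models to cpu automatically based on available device memory
    manager.enable_auto_cpu_offload(device="cuda")

StableDiffusionXLAutoPipeline, StableDiffusionXLModularLoader and the block
base classes (PipelineBlock, SequentialPipelineBlocks, ...) are re-exported
from the package root, see __init__.py below.
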
--- src/diffusers/modular_pipelines/__init__.py | 82 + .../modular_pipelines/components_manager.py | 863 ++++++++ .../modular_pipelines/modular_pipeline.py | 1916 +++++++++++++++++ .../modular_pipeline_utils.py | 598 +++++ .../stable_diffusion_xl/__init__.py | 51 + .../stable_diffusion_xl/after_denoise.py | 259 +++ .../stable_diffusion_xl/before_denoise.py | 1766 +++++++++++++++ .../stable_diffusion_xl/denoise.py | 1362 ++++++++++++ .../stable_diffusion_xl/encoders.py | 856 ++++++++ .../stable_diffusion_xl/modular_loader.py | 175 ++ .../modular_pipeline_presets.py | 119 + 11 files changed, 8047 insertions(+) create mode 100644 src/diffusers/modular_pipelines/__init__.py create mode 100644 src/diffusers/modular_pipelines/components_manager.py create mode 100644 src/diffusers/modular_pipelines/modular_pipeline.py create mode 100644 src/diffusers/modular_pipelines/modular_pipeline_utils.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..cb2ed78ce360 --- /dev/null +++ b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,82 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineMixin", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "ModularLoader", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularLoader, + ModularPipelineMixin, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) + from .components_manager import 
ComponentsManager +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py new file mode 100644 index 000000000000..0ace1b321e8b --- /dev/null +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -0,0 +1,863 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from itertools import combinations +from typing import List, Optional, Union, Dict, Any +import copy + +import torch +import time +from dataclasses import dataclass + +from ..utils import ( + is_accelerate_available, + logging, +) +from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec + + +import uuid + + +if is_accelerate_available(): + from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module + from accelerate.state import PartialState + from accelerate.utils import send_to_device + from accelerate.utils.memory import clear_device_cache + from accelerate.utils.modeling import convert_file_size_to_int + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi Notes: copied from modeling_utils.py (decide later where to put this) +def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to + benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch + discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are + tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm + layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + +class CustomOffloadHook(ModelHook): + """ + A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are + on the given device. Optionally offloads other models to the CPU before the forward pass is called. + + Args: + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. 
+ """ + + def __init__( + self, + execution_device: Optional[Union[str, int, torch.device]] = None, + other_hooks: Optional[List["UserCustomOffloadHook"]] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, + ): + self.execution_device = execution_device if execution_device is not None else PartialState().default_device + self.other_hooks = other_hooks + self.offload_strategy = offload_strategy + self.model_id = None + + def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): + self.offload_strategy = offload_strategy + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + """ + Add a hook to the list of hooks to consider for offloading. + """ + if self.other_hooks is None: + self.other_hooks = [] + self.other_hooks.append(hook) + + def init_hook(self, module): + return module.to("cpu") + + def pre_forward(self, module, *args, **kwargs): + if module.device != self.execution_device: + if self.other_hooks is not None: + hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] + # offload all other hooks + start_time = time.perf_counter() + if self.offload_strategy is not None: + hooks_to_offload = self.offload_strategy( + hooks=hooks_to_offload, + model_id=self.model_id, + model=module, + execution_device=self.execution_device, + ) + end_time = time.perf_counter() + logger.info( + f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" + ) + + for hook in hooks_to_offload: + logger.info( + f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" + ) + hook.offload() + + if hooks_to_offload: + clear_device_cache() + module.to(self.execution_device) + return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) + + +class UserCustomOffloadHook: + """ + A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of + the hook or remove it entirely. + """ + + def __init__(self, model_id, model, hook): + self.model_id = model_id + self.model = model + self.hook = hook + + def offload(self): + self.hook.init_hook(self.model) + + def attach(self): + add_hook_to_module(self.model, self.hook) + self.hook.model_id = self.model_id + + def remove(self): + remove_hook_from_module(self.model) + self.hook.model_id = None + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + self.hook.add_other_hook(hook) + + +def custom_offload_with_hook( + model_id: str, + model: torch.nn.Module, + execution_device: Union[str, int, torch.device] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, +): + hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) + user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) + user_hook.attach() + return user_hook + + +class AutoOffloadStrategy: + """ + Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on + the available memory on the device. 
+ """ + + def __init__(self, memory_reserve_margin="3GB"): + self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) + + def __call__(self, hooks, model_id, model, execution_device): + if len(hooks) == 0: + return [] + + current_module_size = get_memory_footprint(model) + + mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] + mem_on_device = mem_on_device - self.memory_reserve_margin + if current_module_size < mem_on_device: + return [] + + min_memory_offload = current_module_size - mem_on_device + logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") + + # exlucde models that's not currently loaded on the device + module_sizes = dict( + sorted( + {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), + key=lambda x: x[1], + reverse=True, + ) + ) + + def search_best_candidate(module_sizes, min_memory_offload): + """ + search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a + minimum memory offload size. the combination of models should add up to the smallest modulesize that is + larger than `min_memory_offload` + """ + model_ids = list(module_sizes.keys()) + best_candidate = None + best_size = float("inf") + for r in range(1, len(model_ids) + 1): + for candidate_model_ids in combinations(model_ids, r): + candidate_size = sum( + module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids + ) + if candidate_size < min_memory_offload: + continue + else: + if best_candidate is None or candidate_size < best_size: + best_candidate = candidate_model_ids + best_size = candidate_size + + return best_candidate + + best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) + + if best_offload_model_ids is None: + # if no combination is found, meaning that we cannot meet the memory requirement, offload all models + logger.warning("no combination of models to offload to cpu is found, offloading all models") + hooks_to_offload = hooks + else: + hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] + + return hooks_to_offload + + + +class ComponentsManager: + def __init__(self): + self.components = OrderedDict() + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names + self.model_hooks = None + self._auto_offload_enabled = False + + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. + """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. 
+ """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + + def add(self, name, component, collection: Optional[str] = None): + + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(component_id) + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id + + + def remove(self, name: Union[str, List[str]]): + + if name not in self.components: + logger.warning(f"Component '{name}' not found in ComponentsManager") + return + + self.components.pop(name) + self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): + """ + Select components by name with simple pattern matching. 
+ + Args: + names: Component name(s) or pattern(s) + Patterns: + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys + + Returns: + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True + """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return [] if as_name_component_tuples else {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + + if isinstance(names, str): + # Check if this is a "not" pattern + is_not_pattern = names.startswith('!') + if is_not_pattern: + names = names[1:] # Remove the ! 
prefix + + # Handle OR patterns (containing |) + if '|' in names: + terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } + + if is_not_pattern: + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # Prefix match (ends with *) + elif names.endswith('*'): + prefix = names[:-1] + matches = { + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") + else: + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + + # Contains match (starts with *) + elif names.startswith('*'): + search = names[1:-1] if names.endswith('*') else names[1:] + matches = { + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + + else: + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches + + elif isinstance(names, list): + results = {} + for name in names: + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results + + else: + raise ValueError(f"Invalid type for names: {type(names)}") + + def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): + for name, component in self.components.items(): + if 
isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): + remove_hook_from_module(component, recurse=True) + + self.disable_auto_cpu_offload() + offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) + device = torch.device(device) + if device.index is None: + device = torch.device(f"{device.type}:{0}") + all_hooks = [] + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module): + hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) + all_hooks.append(hook) + + for hook in all_hooks: + other_hooks = [h for h in all_hooks if h is not hook] + for other_hook in other_hooks: + if other_hook.hook.execution_device == hook.hook.execution_device: + hook.add_other_hook(other_hook) + + self.model_hooks = all_hooks + self._auto_offload_enabled = True + self._auto_offload_device = device + + def disable_auto_cpu_offload(self): + if self.model_hooks is None: + self._auto_offload_enabled = False + return + + for hook in self.model_hooks: + hook.offload() + hook.remove() + if self.model_hooks: + clear_device_cache() + self.model_hooks = None + self._auto_offload_enabled = False + + # YiYi TODO: add quantization info + def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + """Get comprehensive information about a component. + + Args: + name: Name of the component to get info for + fields: Optional field(s) to return. Can be a string for single field or list of fields. + If None, returns all fields. + + Returns: + Dictionary containing requested component metadata. + If fields is specified, returns only those fields. + If a single field is requested as string, returns just that field's value. 
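To make the `get()` pattern syntax documented earlier concrete, a hedged sketch (assumes `manager` is a ComponentsManager already holding components with base names such as `unet`, `vae`, `text_encoder`, `text_encoder_2`):

# Assumes a populated ComponentsManager; names are illustrative.
manager.get("unet")                         # every component with base name "unet"
manager.get("!unet")                        # everything except base name "unet"
manager.get("text_encoder*")                # base names starting with "text_encoder"
manager.get("vae|unet")                     # exact base names "vae" or "unet"
manager.get("unet", collection="sdxl")      # additionally filter by collection
manager.get(as_name_component_tuples=True)  # [(base_name, component), ...] for everything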
+ """ + if name not in self.components: + raise ValueError(f"Component '{name}' not found in ComponentsManager") + + component = self.components[name] + + # Build complete info dict first + info = { + "model_id": name, + "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), + } + + # Additional info for torch.nn.Module components + if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + + info.update({ + "class_name": component.__class__.__name__, + "size_gb": get_memory_footprint(component) / (1024**3), + "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, + }) + + # Get adapters if applicable + if hasattr(component, "peft_config"): + info["adapters"] = list(component.peft_config.keys()) + + # Check for IP-Adapter scales + if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): + processors = copy.deepcopy(component.attn_processors) + # First check if any processor is an IP-Adapter + processor_types = [v.__class__.__name__ for v in processors.values()] + if any("IPAdapter" in ptype for ptype in processor_types): + # Then get scales only from IP-Adapter processors + scales = { + k: v.scale + for k, v in processors.items() + if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ + } + if scales: + info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + + # If fields specified, filter info + if fields is not None: + if isinstance(fields, str): + # Single field requested, return just that value + return {fields: info.get(fields)} + else: + # List of fields requested, return dict with just those fields + return {k: v for k, v in info.items() if k in fields} + + return info + + def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + + col_widths = { + "name": max(15, max(len(name) for name in 
simple_names)), + "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), + "device": 15, # Reduced since using more compact format + "dtype": 15, + "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) + } + + # Create the header lines + sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + + output = "Components:\n" + sep_line + + # Separate components into models and others + models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} + + # Models section + if models: + output += "Models:\n" + dash_line + # Column headers + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" + output += dash_line + + # Model entries + for name, component in models.items(): + info = self.get_model_info(name) + simple_name = get_simple_name(name) + device_str = format_device(component, info) + dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" + output += dash_line + + # Other components section + if others: + if models: # Add extra newline if we had models section + output += "\n" + output += "Other Components:\n" + dash_line + # Column headers for other components + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" + output += dash_line + + # Other component entries + for name, component in others.items(): + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" + output += dash_line + + # Add additional component info + output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" + for name in self.components: + info = self.get_model_info(name) + if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" + if info.get("adapters") is not None: + output += f" Adapters: {info['adapters']}\n" + if info.get("ip_adapter"): + output += f" IP-Adapter: Enabled\n" + output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" + + return output + + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + """ + Load components from a pretrained model and add them to the manager. + + Args: + pretrained_model_name_or_path (str): The path or identifier of the pretrained model + prefix (str, optional): Prefix to add to all component names loaded from this model. 
+ If provided, components will be named as "{prefix}_{component_name}" + **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() + """ + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend AutoModel to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + + def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) + +def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: + """Summarizes a dictionary by finding common prefixes that share the same value. 
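For the `from_pretrained` and `get_one` helpers above, a hedged sketch (repo id, prefix, and dtype are illustrative):

# Assumes `manager` is a ComponentsManager instance.
import torch

manager.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", prefix="sdxl", torch_dtype=torch.float16
)
# Components are registered as "sdxl_unet", "sdxl_vae", "sdxl_text_encoder", ...
unet = manager.get_one("sdxl_unet")  # raises if zero or more than one component matches
vae = manager.get_one("*vae*")       # the same pattern syntax as `get` applies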
+ + For a dictionary with dot-separated keys like: + { + 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], + 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + } + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: + { + 'down_blocks': [0.6], + 'up_blocks': [0.3] + } + """ + # First group by values - convert lists to tuples to make them hashable + value_to_keys = {} + for key, value in d.items(): + value_tuple = tuple(value) if isinstance(value, list) else value + if value_tuple not in value_to_keys: + value_to_keys[value_tuple] = [] + value_to_keys[value_tuple].append(key) + + def find_common_prefix(keys: List[str]) -> str: + """Find the shortest common prefix among a list of dot-separated keys.""" + if not keys: + return "" + if len(keys) == 1: + return keys[0] + + # Split all keys into parts + key_parts = [k.split('.') for k in keys] + + # Find how many initial parts are common + common_length = 0 + for parts in zip(*key_parts): + if len(set(parts)) == 1: # All parts at this position are the same + common_length += 1 + else: + break + + if common_length == 0: + return "" + + # Return the common prefix + return '.'.join(key_parts[0][:common_length]) + + # Create summary by finding common prefixes for each value group + summary = {} + for value_tuple, keys in value_to_keys.items(): + prefix = find_common_prefix(keys) + if prefix: # Only add if we found a common prefix + # Convert tuple back to list if it was originally a list + value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple + summary[prefix] = value + else: + summary[""] = value # Use empty string if no common prefix + + return summary diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py new file mode 100644 index 000000000000..98960fe25bde --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -0,0 +1,1916 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import traceback +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple, Union, Optional, Type + + +import torch +from tqdm.auto import tqdm +import re +import os +import importlib + +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + PushToHubMixin, +) +from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + format_components, + format_configs, + format_input_params, + format_inputs_short, + format_intermediates_short, + format_output_params, + format_params, + make_doc_string, +) +from .components_manager import ComponentsManager + +from copy import deepcopy +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +MODULAR_LOADER_MAPPING = OrderedDict( + [ + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), + ] +) + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ + self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) + + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ + self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) + + def get_input(self, key: str, default: Any = None) -> Any: + return self.inputs.get(key, default) + + def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. 
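An illustrative sketch of the `PipelineState` API above, including the `kwargs_type` grouping (all names and values are made up):

# PipelineState is the dataclass defined in this module.
import torch

state = PipelineState()
state.add_input("prompt", "a photo of a cat")
state.add_input("guidance_scale", 7.5, kwargs_type="guider_kwargs")
state.add_intermediate("prompt_embeds", torch.randn(1, 77, 2048))

state.get_input("prompt")                  # "a photo of a cat"
state.get_inputs_kwargs("guider_kwargs")   # {"guidance_scale": 7.5}
state.get_intermediate("prompt_embeds")    # the tensor stored above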
+ + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.intermediates.get(key, default) for key in keys} + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. 
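A small sketch of `BlockState`'s attribute/item access as a block would use it inside `__call__` (values are made up):

# BlockState is the container defined in this module.
block_state = BlockState(prompt="a photo of a cat", num_inference_steps=25)
block_state.prompt                    # "a photo of a cat"
block_state["num_inference_steps"]    # 25
block_state["height"] = 1024          # item assignment sets an attribute
block_state.as_dict()                 # {"prompt": ..., "num_inference_steps": 25, "height": 1024}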
+
+        Returns:
+            Dict[str, Any]: Dictionary containing all attributes of the BlockState
+        """
+        return {key: value for key, value in self.__dict__.items()}
+
+    def __repr__(self):
+        def format_value(v):
+            # Handle tensors directly
+            if hasattr(v, "shape") and hasattr(v, "dtype"):
+                return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+
+            # Handle lists of tensors
+            elif isinstance(v, list):
+                if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+                    shapes = [t.shape for t in v]
+                    return f"List[{len(v)}] of Tensors with shapes {shapes}"
+                return repr(v)
+
+            # Handle tuples of tensors
+            elif isinstance(v, tuple):
+                if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+                    shapes = [t.shape for t in v]
+                    return f"Tuple[{len(v)}] of Tensors with shapes {shapes}"
+                return repr(v)
+
+            # Handle dicts with tensor values
+            elif isinstance(v, dict):
+                formatted_dict = {}
+                for k, val in v.items():
+                    if hasattr(val, "shape") and hasattr(val, "dtype"):
+                        formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
+                    elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"):
+                        shapes = [t.shape for t in val]
+                        formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
+                    else:
+                        formatted_dict[k] = repr(val)
+                return formatted_dict
+
+            # Default case
+            return repr(v)
+
+        attributes = "\n".join(f"    {k}: {format_value(v)}" for k, v in self.__dict__.items())
+        return f"BlockState(\n{attributes}\n)"
+
+
+
+class ModularPipelineMixin:
+    """
+    Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
+    """
+
+
+    def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
+        """
+        Create a modular loader; optionally pass a modular_repo to load from the Hub.
+        """
+
+        # Import the components loader (it is a model-specific class)
+        loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
+        diffusers_module = importlib.import_module("diffusers")
+        loader_class = getattr(diffusers_module, loader_class_name)
+
+        # Create deep copies to avoid modifying the original specs
+        component_specs = deepcopy(self.expected_components)
+        config_specs = deepcopy(self.expected_configs)
+        # Create the loader with the copied specs
+        specs = component_specs + config_specs
+
+        self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
+
+
+    @property
+    def default_call_parameters(self) -> Dict[str, Any]:
+        params = {}
+        for input_param in self.inputs:
+            params[input_param.name] = input_param.default
+        return params
+
+    def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+        """
+        Run one or more blocks in sequence; optionally pass in a previous pipeline state.
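A hedged end-to-end sketch of the `setup_loader()` / `run()` flow this mixin provides; the block class, repo id, and output name are assumptions for illustration:

# `MySDXLBlocks` stands in for any ModularPipelineMixin subclass (e.g. a SequentialPipelineBlocks).
blocks = MySDXLBlocks()
blocks.setup_loader(modular_repo="stabilityai/stable-diffusion-xl-base-1.0")

state = blocks.run(prompt="a photo of a cat")                      # returns the full PipelineState
latents = blocks.run(prompt="a photo of a cat", output="latents")  # a single intermediate value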
+ """ + if state is None: + state = PipelineState() + + if not hasattr(self, "loader"): + logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + self.loader = None + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediates_inputs: + state.add_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.add_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.add_input(name, default, kwargs_type) + + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): + + model_name = None + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + + # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable + @property + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. 
Must be implemented by subclasses.""" + return [] + + # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks + @property + def outputs(self) -> List[OutputParam]: + return self.intermediates_outputs + + @property + def required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - use format_components with add_empty_lines=False + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") + + # Inputs section + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str + + # Intermediates section + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates = f"Intermediates:\n{intermediates_str}" + + return ( + f"{class_name}(\n" + f" Class: {base_class}\n" + f"{desc}" + f"{components}\n" + f"{configs}\n" + f" {inputs}\n" + f" {intermediates}\n" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in 
self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): + warnings.warn( + f"Multiple different default values found for input '{input_param.name}': " + f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + else: + combined_dict[input_param.name] = input_param + value_sources[input_param.name] = block_name + + return list(combined_dict.values()) + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. 
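To illustrate the `PipelineBlock` contract above (declared inputs/outputs plus `get_block_state`/`add_block_state`), a minimal hypothetical block; this is a sketch built on the classes in this module, not part of the diff:

import torch
from typing import List

# Hypothetical block: scales the latents by a user-provided factor.
class ScaleLatentsBlock(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Multiplies the latents by a user-provided scale."

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("scale", default=1.0)]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [InputParam("latents", required=True)]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents")]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.latents = block_state.latents * block_state.scale
        self.add_block_state(state, block_state)
        return pipeline, state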
+
+    Args:
+        named_output_lists: List of tuples containing (block_name, output_param_list) pairs
+
+    Returns:
+        List[OutputParam]: Combined list of unique OutputParam objects
+    """
+    combined_dict = {}  # name -> OutputParam
+
+    for block_name, outputs in named_output_lists:
+        for output_param in outputs:
+            if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None):
+                combined_dict[output_param.name] = output_param
+
+    return list(combined_dict.values())
+
+
+class AutoPipelineBlocks(ModularPipelineMixin):
+    """
+    A class that automatically selects a block to run based on the inputs.
+
+    Attributes:
+        block_classes: List of block classes to be used
+        block_names: List of prefixes for each block
+        block_trigger_inputs: List of input names that trigger specific blocks, with None for default
+    """
+
+    block_classes = []
+    block_names = []
+    block_trigger_inputs = []
+
+    def __init__(self):
+        blocks = OrderedDict()
+        for block_name, block_cls in zip(self.block_names, self.block_classes):
+            blocks[block_name] = block_cls()
+        self.blocks = blocks
+        if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+            raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.")
+        default_blocks = [t for t in self.block_trigger_inputs if t is None]
+        # there can be at most one default block (trigger input None), and it must be the last entry;
+        # the order of blocks matters here because the first block with a matching trigger is dispatched
+        # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
+        # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
+        if len(default_blocks) > 1 or (
+            len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None
+        ):
+            raise ValueError(
+                f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
+                "in block_trigger_inputs."
+ ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) + self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + + @property + def required_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + @property + def required_intermediates_inputs(self) -> List[str]: + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediates_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + + @property + def intermediates_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediates_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find default block first (if any) + + block = self.trigger_to_block_map.get(None) + for input_name in self.block_trigger_inputs: + if input_name is not None and 
state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + + try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + + return trigger_inputs + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + return ( + f"{header}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +class SequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + """ + block_classes = [] + block_names = [] + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. 
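A hedged sketch of how an `AutoPipelineBlocks` subclass wires trigger inputs to blocks; the three denoise block classes are assumed to be `PipelineBlock` subclasses defined elsewhere:

# Hypothetical wiring: the first trigger present in the state selects the block, None is the default.
class SDXLAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [InpaintDenoiseBlock, Img2ImgDenoiseBlock, TextToImageDenoiseBlock]
    block_names = ["inpaint", "img2img", "text2img"]
    block_trigger_inputs = ["mask", "image", None]  # "mask" -> inpaint, "image" -> img2img, default -> text2img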
+ + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + for block_name, block in self.blocks.items(): + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return pipeline, state + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. 
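A sketch of composing blocks with `from_blocks_dict` and narrowing the result with `get_execution_blocks`; the block instances are hypothetical and build on the sketch above:

# Hypothetical composition of blocks into a sequential pipeline.
blocks = SequentialPipelineBlocks.from_blocks_dict({
    "text_encoder": TextEncoderBlock(),
    "denoise": SDXLAutoDenoiseStep(),   # the auto block from the sketch above
    "decode": DecodeLatentsBlock(),
})
print(blocks)                                          # repr lists trigger inputs and execution order
img2img_blocks = blocks.get_execution_blocks("image")  # resolves the auto block to its img2img path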
+        Returns: Set[str] containing all unique block_trigger_inputs values
+        """
+        def fn_recursive_get_trigger(blocks):
+            trigger_values = set()
+
+            if blocks is not None:
+                for name, block in blocks.items():
+                    # Check if current block has trigger inputs (i.e. auto block)
+                    if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None:
+                        # Add all non-None values from the trigger inputs list
+                        trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+                    # If block has blocks, recursively check them
+                    if hasattr(block, 'blocks'):
+                        nested_triggers = fn_recursive_get_trigger(block.blocks)
+                        trigger_values.update(nested_triggers)
+
+            return trigger_values
+
+        return fn_recursive_get_trigger(self.blocks)
+
+    @property
+    def trigger_inputs(self):
+        return self._get_trigger_inputs()
+
+    def _traverse_trigger_blocks(self, trigger_inputs):
+        # Convert trigger_inputs to a set for easier manipulation
+        active_triggers = set(trigger_inputs)
+        def fn_recursive_traverse(block, block_name, active_triggers):
+            result_blocks = OrderedDict()
+
+            # sequential (including loopsequential) or PipelineBlock
+            if not hasattr(block, 'block_trigger_inputs'):
+                if hasattr(block, 'blocks'):
+                    # sequential or LoopSequentialPipelineBlocks (keep traversing)
+                    for sub_block_name, sub_block in block.blocks.items():
+                        blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+                        blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()}
+                        result_blocks.update(blocks_to_update)
+                else:
+                    # PipelineBlock
+                    result_blocks[block_name] = block
+                    # Add this block's output names to active triggers if defined
+                    if hasattr(block, 'outputs'):
+                        active_triggers.update(out.name for out in block.outputs)
+                return result_blocks
+
+            # auto
+            else:
+                # Find first block_trigger_input that matches any value in our active_triggers
+                this_block = None
+                matching_trigger = None
+                for trigger_input in block.block_trigger_inputs:
+                    if trigger_input is not None and trigger_input in active_triggers:
+                        this_block = block.trigger_to_block_map[trigger_input]
+                        matching_trigger = trigger_input
+                        break
+
+                # If no matches found, try to get the default (None) block
+                if this_block is None and None in block.block_trigger_inputs:
+                    this_block = block.trigger_to_block_map[None]
+                    matching_trigger = None
+
+                if this_block is not None:
+                    # sequential/auto (keep traversing)
+                    if hasattr(this_block, 'blocks'):
+                        result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
+                    else:
+                        # PipelineBlock
+                        result_blocks[block_name] = this_block
+                        # Add this block's output names to active triggers if defined
+                        # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
+ if hasattr(this_block, 'outputs'): + active_triggers.update(out.name for out in this_block.outputs) + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_execution_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs is not None: + + if not isinstance(trigger_inputs, (list, tuple, set)): + trigger_inputs = [trigger_inputs] + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if trigger_inputs is None: + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [trigger_inputs_all[0]] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + return ( + f"{header}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. 
Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + 
required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif 
value is not None or (value is None and input_param.name not in data):
+                    data[input_param.name] = value
+            elif input_param.kwargs_type:
+                # if kwargs_type is provided, get all intermediates with matching kwargs_type
+                if input_param.kwargs_type not in data:
+                    data[input_param.kwargs_type] = {}
+                intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
+                if intermediates_kwargs:
+                    for k, v in intermediates_kwargs.items():
+                        if v is not None:
+                            if k not in data:
+                                data[k] = v
+                            data[input_param.kwargs_type][k] = v
+        return BlockState(**data)
+
+    def add_block_state(self, state: PipelineState, block_state: BlockState):
+        for output_param in self.intermediates_outputs:
+            if not hasattr(block_state, output_param.name):
+                raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+            param = getattr(block_state, output_param.name)
+            state.add_intermediate(output_param.name, param, output_param.kwargs_type)
+
+# YiYi TODO:
+# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
+# 2. do we need ConfigSpec? seems pretty unnecessary for the loader, we can just add kwargs to the loader
+# 3. add a validator for methods where we accept kwargs to be passed to from_pretrained()
+class ModularLoader(ConfigMixin, PushToHubMixin):
+    """
+    Base class for all Modular pipelines loaders.
+
+    """
+    config_name = "modular_model_index.json"
+
+
+    def register_components(self, **kwargs):
+        """
+        Register components with their corresponding specs.
+        This method is called when a component changes or when __init__ is called.
+
+        Args:
+            **kwargs: Keyword arguments where keys are component names and values are component objects.
+
+        """
+        for name, module in kwargs.items():
+
+            # current component spec
+            component_spec = self._component_specs.get(name)
+            if component_spec is None:
+                logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'")
+                continue
+
+            is_registered = hasattr(self, name)
+
+            if module is not None and not hasattr(module, "_diffusers_load_id"):
+                raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
+
+            # actual library and class name of the module
+            if module is not None:
+                library, class_name = _fetch_class_library_tuple(module)
+                new_component_spec = ComponentSpec.from_component(name, module)
+                component_spec_dict = self._component_spec_to_dict(new_component_spec)
+
+            else:
+                library, class_name = None, None
+                # if module is None, we do not update the spec,
+                # but we still need to update the config to make sure it's synced with the component spec
+                # (on first-time registration, we initialize the object with the component spec, and then call register_components() to register it to the config)
+                new_component_spec = component_spec
+                component_spec_dict = self._component_spec_to_dict(component_spec)
+
+            # do not register if component is not to be loaded from pretrained
+            if new_component_spec.default_creation_method == "from_pretrained":
+                register_dict = {name: (library, class_name, component_spec_dict)}
+            else:
+                register_dict = {}
+
+            # set the component as an attribute
+            # if it is not set yet, just set it and skip the check-and-warn process below
+            if not is_registered:
+                self.register_to_config(**register_dict)
+                self._component_specs[name] = new_component_spec
+                setattr(self, name, module)
+                if module is not None and self._component_manager is not None:
+                    self._component_manager.add(name, module, self._collection)
+                continue
+
+            current_module = getattr(self, name, None)
+            # skip if the component is already registered with the same object
+            if current_module is module:
+                logger.info(f"ModularLoader.register_components: {name} is already registered with the same object, skipping")
+                continue
+
+            # if module is not an instance of the expected type, still register it but with a warning
+            if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
+                logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
+
+            # warn when unregistering a component
+            if current_module is not None and module is None:
+                logger.info(
+                    f"ModularLoader.register_components: setting '{name}' to None "
+                    f"(was {current_module.__class__.__name__})"
+                )
+            # same type, new instance → debug
+            elif current_module is not None \
+                and module is not None \
+                and isinstance(module, current_module.__class__) \
+                and current_module != module:
+                logger.debug(
+                    f"ModularLoader.register_components: replacing existing '{name}' "
+                    f"(same type {type(current_module).__name__}, new instance)"
+                )
+
+            # save modular_model_index.json config
+            self.register_to_config(**register_dict)
+            # update component spec
+            self._component_specs[name] = new_component_spec
+            # finally set models
+            setattr(self, name, module)
+            if module is not None and self._component_manager is not None:
+                self._component_manager.add(name, module, self._collection)
+
+
+    # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
+    def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
+        """
+        Initialize the loader with a list of component specs and config specs.
+        """
+        self._component_manager = component_manager
+        self._collection = collection
+        self._component_specs = {
+            spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)
+        }
+        self._config_specs = {
+            spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)
+        }
+
+        # update component_specs and config_specs from modular_repo
+        if modular_repo is not None:
+            config_dict = self.load_config(modular_repo, **kwargs)
+
+            for name, value in config_dict.items():
+                if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3:
+                    library, class_name, component_spec_dict = value
+                    component_spec = self._dict_to_component_spec(name, component_spec_dict)
+                    self._component_specs[name] = component_spec
+
+                elif name in self._config_specs:
+                    self._config_specs[name].default = value
+
+        register_components_dict = {}
+        for name, component_spec in self._component_specs.items():
+            register_components_dict[name] = None
+        self.register_components(**register_components_dict)
+
+        default_configs = {}
+        for name, config_spec in self._config_specs.items():
+            default_configs[name] = config_spec.default
+        self.register_to_config(**default_configs)
+
+
+    @property
+    def device(self) -> torch.device:
+        r"""
+        Returns:
+            `torch.device`: The torch device on which the pipeline is located.
+        """
+        modules = self.components.values()
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+        for module in modules:
+            return module.device
+
+        return torch.device("cpu")
+
+    @property
+    # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+        Accelerate's module hooks.
+        """
+        for name, model in self.components.items():
+            if not isinstance(model, torch.nn.Module):
+                continue
+
+            if not hasattr(model, "_hf_hook"):
+                return self.device
+            for module in model.modules():
+                if (
+                    hasattr(module, "_hf_hook")
+                    and hasattr(module._hf_hook, "execution_device")
+                    and module._hf_hook.execution_device is not None
+                ):
+                    return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        r"""
+        Returns:
+            `torch.dtype`: The torch dtype on which the pipeline is located.
+        """
+        modules = self.components.values()
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+        for module in modules:
+            return module.dtype
+
+        return torch.float32
+
+
+    @property
+    def components(self) -> Dict[str, Any]:
+        # return only components we've actually set as attributes on self
+        return {
+            name: getattr(self, name)
+            for name in self._component_specs.keys()
+            if hasattr(self, name)
+        }
+
+    def update(self, **kwargs):
+        """
+        Update components and configuration values after the loader has been instantiated.
+
+        This method allows you to:
+        1. Replace existing components with new ones (e.g., updating the unet or text_encoder)
+        2. Update configuration values (e.g., changing the requires_safety_checker flag)
+
+        Args:
+            **kwargs: Component objects or configuration values to update:
+                - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`)
+                - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`)
+
+        Raises:
+            ValueError: If a component wasn't created using ComponentSpec (doesn't have a `_diffusers_load_id` attribute)
+
+        Examples:
+            ```python
+            # Update multiple components at once
+            loader.update(
+                unet=new_unet_model,
+                text_encoder=new_text_encoder
+            )
+
+            # Update configuration values
+            loader.update(
+                requires_safety_checker=False,
+                guidance_rescale=0.7
+            )
+
+            # Update both components and configs together
+            loader.update(
+                unet=new_unet_model,
+                requires_safety_checker=False
+            )
+            ```
+        """
+
+        # extract component_specs_updates & config_specs_updates from `specs`
+        passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
+        passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
+
+        for name, component in passed_components.items():
+            if not hasattr(component, "_diffusers_load_id"):
+                raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
+
+        if len(kwargs) > 0:
+            logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
+
+
+        self.register_components(**passed_components)
+
+
+        config_to_register = {}
+        for name, new_value in passed_config_values.items():
+
+            # e.g. requires_aesthetics_score = False
+            self._config_specs[name].default = new_value
+            config_to_register[name] = new_value
+        self.register_to_config(**config_to_register)
+
+
+    # YiYi TODO: support map for additional from_pretrained kwargs
+    def load(self, component_names: Optional[List[str]] = None, **kwargs):
+        """
+        Load selected components from specs.
+
+        Args:
+            component_names: List of component names to load
+            **kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
+                - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
+                - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
+                - loading fields (e.g. `repo`, `variant`, `revision`) that, when passed, override the corresponding fields of the ComponentSpec for the loaded components
+ """ + if component_names is None: + component_names = list(self._component_specs.keys()) + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") + + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.create(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) + + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + for name in expected_component: + for spec in component_specs: + if spec.name == name: + break + else: + # append a empty component spec for these not in modular_model_index + component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs) + + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. 
+ + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } + """ + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..392d6dcd9521 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -0,0 +1,598 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
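+# Editor's note: the example below is an illustrative sketch added alongside this patch and is
+# not part of the original change. It shows how the `ComponentSpec` defined in this module is
+# intended to be used; the repo id "stabilityai/stable-diffusion-xl-base-1.0" and the variable
+# names are assumptions for illustration only.
+#
+#   import torch
+#   from diffusers import UNet2DConditionModel
+#   from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec
+#
+#   spec = ComponentSpec(
+#       name="unet",
+#       type_hint=UNet2DConditionModel,
+#       repo="stabilityai/stable-diffusion-xl-base-1.0",
+#       subfolder="unet",
+#       default_creation_method="from_pretrained",
+#   )
+#   unet = spec.create(torch_dtype=torch.float16)  # dispatches to create_from_pretrained()
+#   # the created component records where it came from; decode_load_id() maps the "|"-separated
+#   # segments back to the loading fields (repo, subfolder, variant, revision), with "null" for None
+#   print(unet._diffusers_load_id)  # e.g. "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"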
+ +import re +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal + +from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict, ConfigMixin + +if is_torch_available(): + import torch + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + + @classmethod + def from_component(cls, name: str, component: torch.nn.Module) -> Any: + """Create a ComponentSpec from a Component created by `create` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` method") + + type_hint = component.__class__ + + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, **load_spec) + + @classmethod + def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: + """Create a ComponentSpec from a load_id string.""" + if load_id == "null": + raise ValueError("Cannot create ComponentSpec from null load_id") + + # Decode the load_id into a dictionary of loading fields + load_fields = cls.decode_load_id(load_id) + + # Create a new ComponentSpec instance with the decoded fields + return cls(name=name, **load_fields) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). 
+ """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not loaded from pretrained). + """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + # YiYi TODO: add validator + def create(self, **kwargs) -> Any: + """Create the component using the preferred creation method.""" + + # from_pretrained creation + if self.default_creation_method == "from_pretrained": + return self.create_from_pretrained(**kwargs) + elif self.default_creation_method == "from_config": + # from_config creation + return self.create_from_config(**kwargs) + else: + raise ValueError(f"Invalid creation method: {self.default_creation_method}") + + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError( + f"`type_hint` is required when using from_config creation method." + ) + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def create_from_pretrained(self, **kwargs) -> Any: + """Create component using from_pretrained.""" + + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + + if repo != self.repo: + self.repo = repo + for k, v in passed_loading_kwargs.items(): + if v is not None: + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str = None + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + kwargs_type: str = None + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + kwargs_type: str = None + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. 
+ + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add 
description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. 
+ + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. 
+ """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..6d06c1f2e3df --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] + _import_structure["after_denoise"] = ["StableDiffusionXLAutoDecodeStep"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline + from .modular_loader import StableDiffusionXLModularLoader + from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep + from .after_denoise import StableDiffusionXLAutoDecodeStep +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py new file mode 100644 index 000000000000..9746832506d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -0,0 +1,259 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +import numpy as np +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...models import AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...utils import logging + +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...configuration_utils import FrozenDict + +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + +class StableDiffusionXLDecodeLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components + @staticmethod + def upcast_vae(components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + components.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + 
self.upcast_vae(components) + block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif block_state.latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(block_state.latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + ) + block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + else: + block_state.latents = block_state.latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), + InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. 
Can be generated in vae_encode step.") + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLOutputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [InputParam("return_dict", default=True)] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] + + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.return_dict: + block_state.images = (block_state.images,) + else: + block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) + self.add_block_state(state, block_state) + return components, state + + +# After denoise +class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] + block_names = ["decode", "output"] + + @property + def description(self): + return """Decode step that decode the denoised latents into images outputs. +This is a sequential pipeline blocks: + - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images + - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] + block_names = ["decode", "mask_overlay", "output"] + + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ + " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." 
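+# Editor's note: the snippet below is an illustrative sketch added alongside this patch and is
+# not part of the original change. It shows how the branch taken by the auto block defined below
+# can be inspected ahead of time via `get_execution_blocks`, assuming it is composed into a
+# preset such as `StableDiffusionXLAutoPipeline`; the variable names are hypothetical.
+#
+#   blocks = StableDiffusionXLAutoPipeline()
+#   # passing the trigger input name shows the inpaint branch (decode + mask overlay + output);
+#   # at runtime, omitting `padding_mask_crop` makes the auto block fall back to the plain decode branch.
+#   print(blocks.get_execution_blocks("padding_mask_crop"))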
+ + +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return "Decode step that decodes the denoised latents into image outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + +
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..6809b4cd8e2e --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1766 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module + +from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + ModularLoader, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions need to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is.
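Before the individual blocks, a hedged sketch of how the before-denoise blocks defined later in this file could be chained for a plain text-to-image run, following the same SequentialPipelineBlocks pattern used for the decode steps above (the composite class name and the exact ordering are illustrative assumptions, not taken from this patch):

    # Illustrative composition only; the referenced blocks are defined further down in this file.
    class StableDiffusionXLTextToImageBeforeDenoiseStep(SequentialPipelineBlocks):  # hypothetical name
        block_classes = [
            StableDiffusionXLInputStep,                          # batch_size/dtype, duplicate embeddings per prompt
            StableDiffusionXLSetTimestepsStep,                   # scheduler timesteps (+ optional denoising_end cutoff)
            StableDiffusionXLPrepareLatentsStep,                 # initial noise latents
            StableDiffusionXLPrepareAdditionalConditioningStep,  # add_time_ids / negative_add_time_ids / timestep_cond
        ]
        block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]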
+ + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. 
Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + ] + + def check_inputs(self, components, block_state): + + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." 
+ ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? 
+ InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the scheduler's timesteps for inference" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + 
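+ # `denoising_end` in (0, 1) ends denoising early: compute the cutoff in train-timestep units and keep only the timesteps above it (useful when a refiner model handles the remaining noise)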
block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
+ ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." 
+ "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components + # YiYi TODO: refactor using _encode_vae_image + @staticmethod + def prepare_latents_img2img( + components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + init_latents = components.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = components.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = self.prepare_latents_img2img( + components, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Prepare latents step that prepares the latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] + + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * components.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", False),] + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and 
components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. 
Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. 
add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = 
self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = 
[block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + + +# Before denoise +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names 
= ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 000000000000..f605d0ab00aa --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,1362 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from ...configuration_utils import FrozenDict +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import unwrap_module + +from ...guiders import ClassifierFreeGuidance +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, + AutoPipelineBlocks, + LoopSequentialPipelineBlocks, + BlockState, +) +from dataclasses import asdict + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + + return components, block_state + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+ ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." 
+ ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. 
Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "scaled_latents", + required=True, + type_hint=torch.Tensor, + description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. 
block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + 
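+                # Descriptive note on the step below (added commentary, not new behavior): the residuals
+                # computed by the ControlNet call above are passed to the UNet through
+                # `down_block_additional_residuals` / `mid_block_additional_residual`, where they are added
+                # to the corresponding down-block and mid-block features. When guess mode is active, the
+                # unconditional guidance batch reuses the zeroed residuals prepared earlier, so only the
+                # conditional batch is actually conditioned on the ControlNet output.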
guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (3): scheduler step to update latents +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + #YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
+ ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# 
config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if 
torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
+# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") + +# controlnet = unwrap_module(components.controlnet) + +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..3c84fc71c8af --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,856 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor, unwrap_module +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...configuration_utils import FrozenDict + +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec + +import numpy as np + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + 
InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return( + "Text Encoder step that generate text_embeddings to guide the image generation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + 
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + ] + + @staticmethod + def check_inputs(block_state): + + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
+ If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "Vae Encoder step that encode the input image into a latent representation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", required=True), + InputParam("generator"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for 
image-to-image/inpainting generation")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), + ] + + + @property + def description(self) -> str: + return ( + "Vae encoder step that prepares the image and mask for the inpainting process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = 
(image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+        else:
+            image_latents = components.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+    # do not accept do_classifier_free_guidance
+    def prepare_mask_latents(
+        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if not batch_size % masked_image_latents.shape[0] == 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) + + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) + + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + + return components, state + + + +# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) +# Encode +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." 
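[Editor note, illustration only, not part of the patch: the auto blocks above dispatch on which trigger inputs the caller provides. Below is a minimal, self-contained sketch of that selection rule; the function and argument names are hypothetical and do not come from the diffusers API.]

def select_block(block_names, block_trigger_inputs, **inputs):
    # Return the name of the first block whose trigger input was provided (and is not None).
    for name, trigger in zip(block_names, block_trigger_inputs):
        if inputs.get(trigger) is not None:
            return name
    return None

# Mirrors the behavior described for StableDiffusionXLAutoVaeEncoderStep:
# "inpaint" is selected when both mask_image and image are given,
# "img2img" when only image is given.
assert select_block(["inpaint", "img2img"], ["mask_image", "image"], image="img.png", mask_image="mask.png") == "inpaint"
assert select_block(["inpaint", "img2img"], ["mask_image", "image"], image="img.png") == "img2img"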
+ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py new file mode 100644 index 000000000000..53f27571092a --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -0,0 +1,175 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +import PIL +import torch +import numpy as np + +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...image_processor import PipelineImageInput +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging + +from ..modular_pipeline import ModularLoader +from ..modular_pipeline_utils import InputParam, OutputParam + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? 
+class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How 
much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", 
type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", 
type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py new file mode 100644 index 000000000000..80f1780595c2 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -0,0 +1,119 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
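# A minimal sketch, assuming the SDXL_*_SCHEMA dictionaries above are meant as a shared
# registry of parameter specs that pipeline blocks could reuse instead of re-declaring
# each InputParam / OutputParam inline (per the YiYi note they are not wired in yet).
# The block name, the chosen schema keys, and the import paths are assumptions for
# illustration only, not part of this patch series.
from diffusers.modular_pipelines.modular_pipeline import PipelineBlock  # assumed location
from diffusers.modular_pipelines.stable_diffusion_xl.modular_loader import (  # assumed location
    SDXL_INPUTS_SCHEMA,
    SDXL_INTERMEDIATE_OUTPUTS_SCHEMA,
)


class ExampleSchemaReuseBlock(PipelineBlock):  # hypothetical block
    model_name = "stable-diffusion-xl"

    @property
    def inputs(self):
        # reuse the shared specs rather than redefining them per block
        return [
            SDXL_INPUTS_SCHEMA["prompt"],
            SDXL_INPUTS_SCHEMA["negative_prompt"],
            SDXL_INPUTS_SCHEMA["clip_skip"],
        ]

    @property
    def intermediates_outputs(self):
        return [
            SDXL_INTERMEDIATE_OUTPUTS_SCHEMA["prompt_embeds"],
            SDXL_INTERMEDIATE_OUTPUTS_SCHEMA["negative_prompt_embeds"],
        ]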
+ +from typing import Any, List, Optional, Tuple, Union, Dict +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks + +from .denoise import StableDiffusionXLAutoDenoiseStep +from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep +from .after_denoise import StableDiffusionXLAutoDecodeStep +from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "after_denoise"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" + + + +# YiYi notes: comment out for now, work on this later +# # block mapping +# TEXT2IMAGE_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLDecodeStep) +# ]) + +# IMAGE2IMAGE_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLVaeEncoderStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLDecodeStep) +# ]) + +# INPAINT_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), +# ("input", StableDiffusionXLInputStep), +# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +# ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), +# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +# ("denoise", StableDiffusionXLDenoiseStep), +# ("decode", StableDiffusionXLInpaintDecodeStep) +# ]) + +# CONTROLNET_BLOCKS = OrderedDict([ +# ("controlnet_input", StableDiffusionXLControlNetInputStep), +# ("denoise", StableDiffusionXLControlNetDenoiseStep), +# ]) + +# 
CONTROLNET_UNION_BLOCKS = OrderedDict([ +# ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), +# ("denoise", StableDiffusionXLControlNetDenoiseStep), +# ]) + +# IP_ADAPTER_BLOCKS = OrderedDict([ +# ("ip_adapter", StableDiffusionXLIPAdapterStep), +# ]) + +# AUTO_BLOCKS = OrderedDict([ +# ("text_encoder", StableDiffusionXLTextEncoderStep), +# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), +# ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), +# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), +# ("denoise", StableDiffusionXLAutoDenoiseStep), +# ("decode", StableDiffusionXLAutoDecodeStep) +# ]) + +# AUTO_CORE_BLOCKS = OrderedDict([ +# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), +# ("denoise", StableDiffusionXLAutoDenoiseStep), +# ]) + + +# SDXL_SUPPORTED_BLOCKS = { +# "text2img": TEXT2IMAGE_BLOCKS, +# "img2img": IMAGE2IMAGE_BLOCKS, +# "inpaint": INPAINT_BLOCKS, +# "controlnet": CONTROLNET_BLOCKS, +# "controlnet_union": CONTROLNET_UNION_BLOCKS, +# "ip_adapter": IP_ADAPTER_BLOCKS, +# "auto": AUTO_BLOCKS +# } + + From 153ae34ff6d8c0832b7d2db2aabcf4e27f0eb1e4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 10 May 2025 03:50:47 +0200 Subject: [PATCH 17/38] update __init__ --- src/diffusers/__init__.py | 48 +++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa3e88d999b5..7a3de0b95747 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,6 +39,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "modular_pipelines": [], "quantizers.quantization_config": [], "schedulers": [], "utils": [ @@ -254,13 +255,19 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ModularLoader", + "ComponentSpec", + "ComponentsManager", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ @@ -509,12 +516,10 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLAutoPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -541,6 +546,24 @@ ] ) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] + +else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoPipeline", + "StableDiffusionXLModularLoader", + ] + ) try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -864,12 +887,16 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, ) + from .modular_pipelines import ( + ModularLoader, + ComponentSpec, + ComponentsManager, + ) from .quantizers import 
DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -1097,12 +1124,10 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, @@ -1127,7 +1152,16 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() From 796453cad12d62dbe48db156df925cd5392cca31 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:14:43 +0200 Subject: [PATCH 18/38] add notes --- .../modular_pipelines/modular_pipeline_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 392d6dcd9521..a82f83fc38d9 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -241,6 +241,13 @@ class ConfigSpec: name: str default: Any description: Optional[str] = None + + +# YiYi Notes: both inputs and intermediates_inputs are InputParam objects +# however some fields are not relevant for intermediates_inputs +# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed +# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs +# -> should we use different class for inputs and intermediates_inputs? 
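# A minimal sketch of the distinction described in the note above, using parameter specs
# that mirror the SDXL blocks later in this series (the exact parameters chosen here are
# illustrative only). The same InputParam class serves both lists, but the fields carry
# different weight in each.

# In a block's `inputs` (user-provided values): `default` is applied and `required=True`
# is actually checked when the pipeline runs.
InputParam("num_images_per_prompt", default=1)
InputParam("image", required=True)

# In a block's `intermediates_inputs` (values produced by earlier blocks and read from
# PipelineState): `required` only shows up in the generated docstring and `default` is
# ignored, since any default comes from the original `inputs` declaration.
InputParam("generator")
InputParam("batch_size", required=True, type_hint=int, description="Can be generated in input step")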
@dataclass class InputParam: """Specification for an input parameter.""" @@ -249,7 +256,7 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None + kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" From 144eae4e0bb3368d9f617d7c54761e86128a0289 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:16:42 +0200 Subject: [PATCH 19/38] add block state will also make sure modifed intermediates_inputs will be updated --- .../modular_pipelines/modular_pipeline.py | 241 +++++++++++++++--- 1 file changed, 206 insertions(+), 35 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 98960fe25bde..3eeff41dd1de 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -282,7 +282,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, state = PipelineState() if not hasattr(self, "loader"): - logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") self.loader = None # Make a copy of the input kwargs @@ -313,7 +313,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, # Warn about unexpected inputs if len(passed_kwargs) > 0: - logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -373,7 +373,6 @@ def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" @@ -389,13 +388,16 @@ def intermediates_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] + def _get_outputs(self): + return self.intermediates_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? 
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @property def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs + return self._get_outputs() - @property - def required_inputs(self) -> List[str]: + def _get_required_inputs(self): input_names = [] for input_param in self.inputs: if input_param.required: @@ -403,13 +405,23 @@ def required_inputs(self) -> List[str]: return input_names @property - def required_intermediates_inputs(self) -> List[str]: + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + + def _get_required_intermediates_inputs(self): input_names = [] for input_param in self.intermediates_inputs: if input_param.required: input_names.append(input_param.name) return input_names + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + return self._get_required_intermediates_inputs() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -521,6 +533,30 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. 
+ intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li input_param.default is not None and current_param.default != input_param.default): warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." ) if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name return list(combined_dict.values()) @@ -661,7 +697,9 @@ def required_inputs(self) -> List[str]: required_by_all.intersection_update(block_required) return list(required_by_all) - + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: first_block = next(iter(self.blocks.values())) @@ -838,14 +876,21 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result @property @@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin): block_classes = [] block_names = [] - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name @property def description(self): return "" + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + @property def expected_components(self): @@ -929,6 +976,8 @@ def required_inputs(self) -> List[str]: return list(required_by_any) + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -960,11 +1009,15 @@ def intermediates_inputs(self) -> List[str]: def get_intermediates_inputs(self): inputs = [] outputs = set() + added_inputs = set() # Go through all blocks in order for 
block in self.blocks.values(): # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + for inp in block.intermediates_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -1176,14 +1229,21 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result @property @@ -1348,7 +1408,8 @@ def required_inputs(self) -> List[str]: return list(required_by_any) - # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -1384,6 +1445,22 @@ def __init__(self): for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance def loop_step(self, components, state: PipelineState, **kwargs): @@ -1455,6 +1532,100 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. 
+ intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + + # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? 
seems pretty unnecessrary for loader, can just add and kwargs to the loader From 522e82762566597de63afd185f9bc02589035674 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 12 May 2025 01:17:45 +0200 Subject: [PATCH 20/38] move block mappings to its own file --- .../modular_pipeline_block_mappings.py | 128 ++++++++++++++++++ .../modular_pipeline_presets.py | 76 ----------- 2 files changed, 128 insertions(+), 76 deletions(-) create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py new file mode 100644 index 000000000000..c739a24e9759 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +# Import all the necessary block classes +from .denoise import ( + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLControlNetDenoiseStep +) +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep +) +from .encoders import ( + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLVaeEncoderStep, + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep +) +from .after_denoise import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintDecodeStep +) +from .after_denoise import StableDiffusionXLAutoDecodeStep + + +# YiYi notes: comment out for now, work on this later +# block mapping +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", 
StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + +AUTO_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + +AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 80f1780595c2..6ea327047740 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -40,80 +40,4 @@ def description(self): -# YiYi notes: comment out for now, work on this later -# # block mapping -# TEXT2IMAGE_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLDecodeStep) -# ]) - -# IMAGE2IMAGE_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLVaeEncoderStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLDecodeStep) -# ]) - -# INPAINT_BLOCKS = OrderedDict([ -# 
("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), -# ("input", StableDiffusionXLInputStep), -# ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), -# ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), -# ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), -# ("denoise", StableDiffusionXLDenoiseStep), -# ("decode", StableDiffusionXLInpaintDecodeStep) -# ]) - -# CONTROLNET_BLOCKS = OrderedDict([ -# ("controlnet_input", StableDiffusionXLControlNetInputStep), -# ("denoise", StableDiffusionXLControlNetDenoiseStep), -# ]) - -# CONTROLNET_UNION_BLOCKS = OrderedDict([ -# ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), -# ("denoise", StableDiffusionXLControlNetDenoiseStep), -# ]) - -# IP_ADAPTER_BLOCKS = OrderedDict([ -# ("ip_adapter", StableDiffusionXLIPAdapterStep), -# ]) - -# AUTO_BLOCKS = OrderedDict([ -# ("text_encoder", StableDiffusionXLTextEncoderStep), -# ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), -# ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), -# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), -# ("denoise", StableDiffusionXLAutoDenoiseStep), -# ("decode", StableDiffusionXLAutoDecodeStep) -# ]) - -# AUTO_CORE_BLOCKS = OrderedDict([ -# ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), -# ("denoise", StableDiffusionXLAutoDenoiseStep), -# ]) - - -# SDXL_SUPPORTED_BLOCKS = { -# "text2img": TEXT2IMAGE_BLOCKS, -# "img2img": IMAGE2IMAGE_BLOCKS, -# "inpaint": INPAINT_BLOCKS, -# "controlnet": CONTROLNET_BLOCKS, -# "controlnet_union": CONTROLNET_UNION_BLOCKS, -# "ip_adapter": IP_ADAPTER_BLOCKS, -# "auto": AUTO_BLOCKS -# } - From 5cde77f9159d9bf1deeb948a4db79d109df461d7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 01:52:51 +0200 Subject: [PATCH 21/38] make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables --- .../modular_pipelines/modular_pipeline.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 3eeff41dd1de..5dcb903db495 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -17,6 +17,7 @@ from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union, Optional, Type +from copy import deepcopy import torch @@ -109,7 +110,9 @@ def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} @@ -483,6 +486,7 @@ def doc(self): ) + # YiYi TODO: input and inteermediate inputs with same name? should warn? 
def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} @@ -1032,14 +1036,21 @@ def get_intermediates_inputs(self): @property def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediates_outputs + if name not in inp_names: + named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): From 58358c2d003f7a25120aea9c4545571d6feefe21 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 01:57:47 +0200 Subject: [PATCH 22/38] decode block, if skip decoding do not need to update latent --- .../stable_diffusion_xl/after_denoise.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py index 9746832506d7..6ce59b5c35b9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -98,16 +98,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) if not block_state.output_type == "latent": + latents = block_state.latents # make sure the VAE is in float32 mode, as it overflows in float16 block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast if block_state.needs_upcasting: self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) + components.vae = components.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None @@ -119,16 +120,16 @@ def __call__(self, components, state: PipelineState) -> PipelineState: ) if block_state.has_latents_mean and block_state.has_latents_std: block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor + latents = latents / components.vae.config.scaling_factor - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] + block_state.images = components.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if block_state.needs_upcasting: @@ -186,6 +187,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state +# YiYi TODO: remove this, we don't need this in modular class StableDiffusionXLOutputStep(PipelineBlock): model_name = "stable-diffusion-xl" From 506a8ea09c19d806103c23e69d3dd52aa7e84110 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 04:36:06 +0200 Subject: [PATCH 23/38] fix imports --- .../pipelines/stable_diffusion_xl/__init__.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 006836fe30d4..8088fbcfceba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,18 +29,6 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] - _import_structure["pipeline_stable_diffusion_xl_modular"] = [ - "StableDiffusionXLControlNetDenoiseStep", - "StableDiffusionXLDecodeLatentsStep", - "StableDiffusionXLDenoiseStep", - "StableDiffusionXLInputStep", - "StableDiffusionXLModularLoader", - "StableDiffusionXLPrepareAdditionalConditioningStep", - "StableDiffusionXLPrepareLatentsStep", - "StableDiffusionXLSetTimestepsStep", - "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLAutoPipeline", - ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -60,18 +48,6 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from 
.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline - from .pipeline_stable_diffusion_xl_modular import ( - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLModularLoader, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, - ) try: if not (is_transformers_available() and is_flax_available()): From e2491af650b33c43294f0aaac02f0b7fdbbcf7e0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 20:42:57 +0200 Subject: [PATCH 24/38] fix import --- src/diffusers/pipelines/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0567eb687c62..a988fb6702aa 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,7 +47,6 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -481,7 +480,6 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, From a0deefb6061408a5ff6523ceed24a0fa31c30b20 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 20:51:21 +0200 Subject: [PATCH 25/38] fix more --- src/diffusers/pipelines/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a988fb6702aa..011f23ed371c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -329,8 +329,6 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularLoader", - "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -704,9 +702,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( From a7fb2d2a2243d4687a2b9c05ca0fdec21fdb9ffb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 22:15:54 +0200 Subject: [PATCH 26/38] remove the output step --- .../stable_diffusion_xl/after_denoise.py | 56 ++----------------- .../stable_diffusion_xl/modular_loader.py | 1 - 2 files changed, 5 insertions(+), 52 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py index 6ce59b5c35b9..ca848e20984f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -41,7 +41,7 @@ -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): +class StableDiffusionXLDecodeStep(PipelineBlock): 
model_name = "stable-diffusion-xl" @@ -187,63 +187,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state -# YiYi TODO: remove this, we don't need this in modular -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.return_dict: - block_state.images = (block_state.images,) - else: - block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) - self.add_block_state(state, block_state) - return components, state - - -# After denoise -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decode the denoised latents into images outputs. -This is a sequential pipeline blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] @property def description(self): return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ "This is a sequential pipeline blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." 
+ " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 53f27571092a..4af942af64e6 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -107,7 +107,6 @@ def num_channels_latents(self): "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), From 8ad14a52cbc3b3e0d7f97305dc95fee629564b97 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 23:25:56 +0200 Subject: [PATCH 27/38] make generator intermediates (it is mutable) --- .../stable_diffusion_xl/before_denoise.py | 6 +++--- .../modular_pipelines/stable_diffusion_xl/denoise.py | 4 ++-- .../modular_pipelines/stable_diffusion_xl/encoders.py | 8 +++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 6809b4cd8e2e..8f083f1870e7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -440,7 +440,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), @@ -459,6 +458,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "batch_size", required=True, @@ -733,7 +733,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), @@ -742,6 +741,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. 
Can be generated in set_timesteps step."), InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), @@ -879,7 +879,6 @@ def inputs(self) -> List[InputParam]: return [ InputParam("height"), InputParam("width"), - InputParam("generator"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), ] @@ -887,6 +886,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam( "batch_size", required=True, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index f605d0ab00aa..b29920764acb 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -485,13 +485,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("eta", default=0.0), ] @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "latents", required=True, @@ -554,13 +554,13 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("generator"), InputParam("eta", default=0.0), ] @property def intermediates_inputs(self) -> List[str]: return [ + InputParam("generator"), InputParam( "timesteps", required=True, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 3c84fc71c8af..ca4efe2c4a7f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -568,7 +568,6 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("image", required=True), - InputParam("generator"), InputParam("height"), InputParam("width"), ] @@ -576,6 +575,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: return [ + InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @@ -680,7 +680,6 @@ def inputs(self) -> List[InputParam]: return [ InputParam("height"), InputParam("width"), - InputParam("generator"), InputParam("image", required=True), InputParam("mask_image", required=True), InputParam("padding_mask_crop"), @@ -688,7 +687,10 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] + return [ + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] @property def intermediates_outputs(self) -> List[OutputParam]: From 96ce6744fe4c7a569fd1cb5e42ce7d188b85eb1e Mon Sep 17 
00:00:00 2001 From: yiyixuxu Date: Thu, 15 May 2025 00:45:45 +0200 Subject: [PATCH 28/38] after_denoise -> decoders --- .../modular_pipelines/stable_diffusion_xl/__init__.py | 4 ++-- .../stable_diffusion_xl/{after_denoise.py => decoders.py} | 0 .../stable_diffusion_xl/modular_pipeline_presets.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/diffusers/modular_pipelines/stable_diffusion_xl/{after_denoise.py => decoders.py} (100%) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 6d06c1f2e3df..f3f961d61a13 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -24,7 +24,7 @@ _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] - _import_structure["after_denoise"] = ["StableDiffusionXLAutoDecodeStep"] + _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -36,7 +36,7 @@ from .modular_pipeline_presets import StableDiffusionXLAutoPipeline from .modular_loader import StableDiffusionXLModularLoader from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep - from .after_denoise import StableDiffusionXLAutoDecodeStep + from .decoders import StableDiffusionXLAutoDecodeStep else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py similarity index 100% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 6ea327047740..637c7ac306d7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -18,7 +18,7 @@ from .denoise import StableDiffusionXLAutoDenoiseStep from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep -from .after_denoise import StableDiffusionXLAutoDecodeStep +from .decoders import StableDiffusionXLAutoDecodeStep from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -26,7 +26,7 @@ class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "after_denoise"] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] @property def description(self): From 27c1158b23fc06c03a1bb8f9d730d22c394421f5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:50:03 +0200 Subject: 
[PATCH 29/38] add a to-do for guider cconfig mixin --- src/diffusers/hooks/layer_skip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index c50d2b7471e4..65a99464ba2f 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -30,6 +30,8 @@ _LAYER_SKIP_HOOK = "layer_skip_hook" +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable @dataclass class LayerSkipConfig: r""" From d0fbf745e6e27185a8c465ced3373e2f77cf37e2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:52:12 +0200 Subject: [PATCH 30/38] refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method --- .../modular_pipeline_utils.py | 72 ++++++++----------- 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index a82f83fc38d9..0c6d1b585589 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -71,34 +71,31 @@ def __eq__(self, other): self.default_creation_method == other.default_creation_method) @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component created by `create` or `load` method.""" if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") + raise ValueError("Component is not created by `create` or `load` method") + # throw a error if component is created with `create` method but not a subclass of ConfigMixin + # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + "We currently only support creating ComponentSpec from a component with " + "created with `ComponentSpec.load` method" + "or created with `ComponentSpec.create` and a subclass of ConfigMixin" + ) type_hint = component.__class__ + default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + if isinstance(component, ConfigMixin): config = component.config else: config = None load_spec = cls.decode_load_id(component._diffusers_load_id) - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) @classmethod def loading_fields(cls) -> List[str]: @@ -137,7 +134,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: "revision": "revision" } If a segment value is "null", it's replaced 
with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). + Returns None if load_id is "null" (indicating component not created with `load` method). """ # Get all loading fields in order @@ -158,20 +155,12 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: return result - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): @@ -201,34 +190,35 @@ def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] return component # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" + def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") if self.type_hint is None: try: from diffusers import AutoModel component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully self.type_hint = component.__class__ else: try: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} using load method: {e}") - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) component._diffusers_load_id = self.load_id return component From 163341d3dd6c7ca8d375630a3b41363d1da3c9ce Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:58:26 +0200 Subject: [PATCH 31/38] refactor modular loader: 1. load only load (pretrained components only if not specific names) 2. update acceept create spec 3. move the updte _componeent_spec logic outside register_components to each method that create/update the component: __init__/update/load --- .../modular_pipelines/modular_pipeline.py | 124 ++++++++++++------ 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5dcb903db495..1c67a3871764 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1651,54 +1651,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the modular_model_index.json configuration for serialization + 4. Adds components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + Args: **kwargs: Keyword arguments where keys are component names and values are component objects. 
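A minimal usage sketch of the refactored `ComponentSpec` API from this patch: `load()` wraps `from_pretrained` using the spec's loading fields, `create()` builds config-only components, and `from_component()` rebuilds a spec from a component produced by either method. The import path, repo id and subfolder below are illustrative assumptions, not part of the patch itself.

import torch
from diffusers import UNet2DConditionModel
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec  # assumed import path

# spec for a pretrained component; loading fields (repo/subfolder/variant/revision) live on the spec
unet_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",  # assumed repo for illustration
    subfolder="unet",
)
unet = unet_spec.load(torch_dtype=torch.float16)           # load() delegates to UNet2DConditionModel.from_pretrained
rebuilt_spec = ComponentSpec.from_component("unet", unet)  # round-trips via the component's _diffusers_load_id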
+ E.g., register_components(unet=unet_model, text_encoder=encoder_model) + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) + # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - # actual library and class name of the module - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) else: + # if module is None, e.g. self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based oncomponent spec library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below if not is_registered: self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec setattr(self, name, module) - if module is not None and self._component_manager is not None: + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue @@ -1707,10 +1721,6 @@ def register_components(self, **kwargs): if current_module is module: logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - 
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: @@ -1718,7 +1728,7 @@ def register_components(self, **kwargs): f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) - # same type, new instance → debug + # same type, new instance → replace but send debug log elif current_module is not None \ and module is not None \ and isinstance(module, current_module.__class__) \ @@ -1728,13 +1738,12 @@ def register_components(self, **kwargs): f"(same type {type(current_module).__name__}, new instance)" ) - # save modular_model_index.json config + # update modular_model_index.json config self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) - if module is not None and self._component_manager is not None: + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) @@ -1758,6 +1767,7 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): + # only update component_spec for from_pretrained components if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) @@ -1768,7 +1778,11 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: register_components_dict = {} for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component self.register_components(**register_components_dict) default_configs = {} @@ -1870,6 +1884,7 @@ def update(self, **kwargs): **kwargs: Component objects or configuration values to update: - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it Raises: ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) @@ -1893,22 +1908,52 @@ def update(self, **kwargs): unet=new_unet_model, requires_safety_checker=False ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) ``` """ # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k 
in kwargs and not isinstance(kwargs[k], ComponentSpec)} passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in update() method " + f"because it is not supported in `ComponentSpec.from_component()`. " + f"Please pass a ComponentSpec object instead." + ) + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - self.register_components(**passed_components) + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) config_to_register = {} @@ -1932,8 +1977,9 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. 
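A minimal sketch of how `load()` is meant to be called after this change, assuming `loader` is an already constructed `ModularLoader` (e.g. a `StableDiffusionXLModularLoader`); the per-component `torch_dtype` dict mirrors the docstring above.

import torch

loader.load()                                   # loads every component whose default_creation_method is "from_pretrained"
loader.load(component_names=["unet", "vae"])    # load only a subset by name
loader.load(torch_dtype={"unet": torch.bfloat16, "default": torch.float32})  # per-component dtypes with a "default" fallback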
""" + # if not specific name, load all the components with default_creation_method == "from_pretrained" if component_names is None: - component_names = list(self._component_specs.keys()) + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] elif not isinstance(component_names, list): component_names = [component_names] @@ -1958,7 +2004,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): # check if the default is specified component_load_kwargs[key] = value["default"] try: - components_to_register[name] = spec.create(**component_load_kwargs) + components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") @@ -1986,7 +2032,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) @@ -2010,7 +2056,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # append a empty component spec for these not in modular_model_index component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) @staticmethod From 73ab5725c2fad4f62589554c9432c7b0dd673268 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 19:09:01 +0200 Subject: [PATCH 32/38] update components manager --- .../modular_pipelines/components_manager.py | 143 +++++++++++------- 1 file changed, 89 insertions(+), 54 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 0ace1b321e8b..88910baf90f4 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -243,78 +243,112 @@ def __init__(self): self._auto_offload_enabled = False - def _get_by_collection(self, collection: str): + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ - Select components by collection name. + Lookup component_ids by name, collection, or load_id. 
""" - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components + if components is None: + components = self.components + + if name: + ids_by_name = set() + for component_id, component in components.items(): + comp_name = "_".join(component_id.split("_")[:-1]) + if comp_name == name: + ids_by_name.add(component_id) + else: + ids_by_name = set(components.keys()) + if collection: + ids_by_collection = set() + for component_id, component in components.items(): + if component_id in self.collections[collection]: + ids_by_collection.add(component_id) + else: + ids_by_collection = set(self.collections.keys()) + if load_id: + ids_by_load_id = set() + for name, component in components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + ids_by_load_id.add(name) + else: + ids_by_load_id = set(components.keys()) - - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) + return ids def add(self, name, component, collection: Optional[str] = None): + + component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated components for comp_id, comp in self.components.items(): if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id + logger.warning( + f"component '{component}' already exists as '{comp_id}'" + ) + # if name is the same, use the existing component_id + if comp_id.split("_")[:-1] == component_id.split("_")[:-1]: + component_id = comp_id + break - component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) + existing = ", ".join(components_with_same_load_id) logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." + f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." 
) - # add component to components manager self.components[component_id] = component self.added_time[component_id] = time.time() + if collection: if collection not in self.collections: self.collections[collection] = set() + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + else: + logger.info(f"Added component '{name}' as '{component_id}'") if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") return component_id - def remove(self, name: Union[str, List[str]]): + def remove(self, component_id: str = None): - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - - self.components.pop(name) - self.added_time.pop(name) + + component = self.components.pop(component_id) + self.added_time.pop(component_id) for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) + if component_id in self.collections[collection]: + self.collections[collection].remove(component_id) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) + else: + if isinstance(component, torch.nn.Module): + component.to("cpu") + del component + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, as_name_component_tuples: bool = False): @@ -343,16 +377,8 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N or list of (base_name, component) tuples if as_name_component_tuples=True """ - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} # Helper to extract base name from component_id def get_base_name(component_id): @@ -542,11 +568,11 @@ def disable_auto_cpu_offload(self): self._auto_offload_enabled = False # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: - name: Name of the component to get info for + component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. @@ -555,16 +581,16 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No If fields is specified, returns only those fields. 
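A minimal sketch of the reworked `ComponentsManager` API touched in this patch (components are registered under a generated `name_<uuid>` id and looked up by name, collection, load_id or component_id); the import path and the `unet` object are illustrative assumptions.

from diffusers.modular_pipelines.components_manager import ComponentsManager  # assumed import path

manager = ComponentsManager()
# `unet` is assumed to be a component created via ComponentSpec.load(...)
unet_id = manager.add("unet", unet, collection="sdxl")   # returns an id like "unet_<uuid>"
info = manager.get_model_info(unet_id)                   # dict with model_id, added_time, collection, ...
same_unet = manager.get_one(component_id=unet_id)        # or get_one(name="unet", collection="sdxl")
manager.remove(unet_id)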
If a single field is requested as string, returns just that field's value. """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") - component = self.components[name] + component = self.components[component_id] # Build complete info dict first info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": next((coll for coll, comps in self.collections.items() if component_id in comps), None), } # Additional info for torch.nn.Module components @@ -776,7 +802,7 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: """ Get a single component by name. Raises an error if multiple components match or none are found. @@ -791,6 +817,15 @@ def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, Raises: ValueError: If no components match or multiple components match """ + + # if component_id is provided, return the component + if component_id is not None and (name is not None or collection is not None or load_id is not None): + raise ValueError(" if component_id is provided, name, collection, and load_id must be None") + elif component_id is not None: + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + return self.components[component_id] + results = self.get(name, collection, load_id) if not results: From 61dac3bbe4ad8d71c5239e9d6158819f99069a20 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 19 May 2025 22:39:32 +0200 Subject: [PATCH 33/38] up --- .../modular_pipelines/components_manager.py | 84 +++++++++++++------ .../modular_pipelines/modular_pipeline.py | 15 ++-- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 88910baf90f4..992353389b95 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -253,7 +253,7 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde if name: ids_by_name = set() for component_id, component in components.items(): - comp_name = "_".join(component_id.split("_")[:-1]) + comp_name = self._id_to_name(component_id) if comp_name == name: ids_by_name.add(component_id) else: @@ -264,7 +264,7 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde if component_id in self.collections[collection]: ids_by_collection.add(component_id) else: - ids_by_collection = set(self.collections.keys()) + ids_by_collection = set(components.keys()) if load_id: ids_by_load_id = set() for name, component in components.items(): @@ -276,6 +276,9 @@ def _lookup_ids(self, name=None, collection=None, load_id=None, components: Orde ids = 
ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) return ids + @staticmethod + def _id_to_name(component_id: str): + return "_".join(component_id.split("_")[:-1]) def add(self, name, component, collection: Optional[str] = None): @@ -284,18 +287,24 @@ def add(self, name, component, collection: Optional[str] = None): # check for duplicated components for comp_id, comp in self.components.items(): if comp == component: - logger.warning( - f"component '{component}' already exists as '{comp_id}'" - ) - # if name is the same, use the existing component_id - if comp_id.split("_")[:-1] == component_id.split("_")[:-1]: + comp_name = self._id_to_name(comp_id) + if comp_name == name: + logger.warning( + f"component '{name}' already exists as '{comp_id}'" + ) component_id = comp_id break + else: + logger.warning( + f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"To remove a duplicate, call `components_manager.remove('')`." + ) # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] if components_with_same_load_id: existing = ", ".join(components_with_same_load_id) @@ -311,12 +320,13 @@ def add(self, name, component, collection: Optional[str] = None): if collection: if collection not in self.collections: self.collections[collection] = set() - comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) - for comp_id in comp_ids_in_collection: - logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") - self.remove(comp_id) - self.collections[collection].add(component_id) - logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + if not component_id in self.collections[collection]: + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) + self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") else: logger.info(f"Added component '{name}' as '{component_id}'") @@ -590,7 +600,7 @@ def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str info = { "model_id": component_id, "added_time": self.added_time[component_id], - "collection": next((coll for coll, comps in self.collections.items() if component_id in comps), None), + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } # Additional info for torch.nn.Module components @@ -676,11 +686,19 @@ def format_device(component, info): ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find 
the maximum collection name length + all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 col_widths = { "name": max(15, max(len(name) for name in simple_names)), @@ -689,7 +707,7 @@ def format_device(component, info): "dtype": 15, "size": 10, "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) + "collection": max_collection_len } # Create the header lines @@ -718,11 +736,21 @@ def format_device(component, info): device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) - collection = info["collection"] or "N/A" + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + output += dash_line # Other components section @@ -738,9 +766,17 @@ def format_device(component, info): for name, component in others.items(): info = self.get_model_info(name) simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += dash_line # Add additional component info diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 1c67a3871764..36273da11f5a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -2043,19 +2043,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P for name, value in config_dict.items(): if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) + # only pick up pretrained components from the repo + if component_spec_dict.get("repo", None) is not None: + component_spec = cls._dict_to_component_spec(name, 
component_spec_dict) + component_specs.append(component_spec) elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) From 4968edc5dc499d472e88c5637dc2afd968f5bcbe Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:07:27 +0200 Subject: [PATCH 34/38] remove the duplicated components_manager file I forgot to deletee --- src/diffusers/pipelines/components_manager.py | 862 ------------------ 1 file changed, 862 deletions(-) delete mode 100644 src/diffusers/pipelines/components_manager.py diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py deleted file mode 100644 index bdff133e22d9..000000000000 --- a/src/diffusers/pipelines/components_manager.py +++ /dev/null @@ -1,862 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy - -import torch -import time -from dataclasses import dataclass - -from ..utils import ( - is_accelerate_available, - logging, -) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec - - -if is_accelerate_available(): - from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module - from accelerate.state import PartialState - from accelerate.utils import send_to_device - from accelerate.utils.memory import clear_device_cache - from accelerate.utils.modeling import convert_file_size_to_int - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# YiYi Notes: copied from modeling_utils.py (decide later where to put this) -def get_memory_footprint(self, return_buffers=True): - r""" - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to - benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch - discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 - - Arguments: - return_buffers (`bool`, *optional*, defaults to `True`): - Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are - tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm - layers. 
Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 - """ - mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) - if return_buffers: - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) - mem = mem + mem_bufs - return mem - - -class CustomOffloadHook(ModelHook): - """ - A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are - on the given device. Optionally offloads other models to the CPU before the forward pass is called. - - Args: - execution_device(`str`, `int` or `torch.device`, *optional*): - The device on which the model should be executed. Will default to the MPS device if it's available, then - GPU 0 if there is a GPU, and finally to the CPU. - """ - - def __init__( - self, - execution_device: Optional[Union[str, int, torch.device]] = None, - other_hooks: Optional[List["UserCustomOffloadHook"]] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, - ): - self.execution_device = execution_device if execution_device is not None else PartialState().default_device - self.other_hooks = other_hooks - self.offload_strategy = offload_strategy - self.model_id = None - - def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): - self.offload_strategy = offload_strategy - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - """ - Add a hook to the list of hooks to consider for offloading. - """ - if self.other_hooks is None: - self.other_hooks = [] - self.other_hooks.append(hook) - - def init_hook(self, module): - return module.to("cpu") - - def pre_forward(self, module, *args, **kwargs): - if module.device != self.execution_device: - if self.other_hooks is not None: - hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] - # offload all other hooks - start_time = time.perf_counter() - if self.offload_strategy is not None: - hooks_to_offload = self.offload_strategy( - hooks=hooks_to_offload, - model_id=self.model_id, - model=module, - execution_device=self.execution_device, - ) - end_time = time.perf_counter() - logger.info( - f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" - ) - - for hook in hooks_to_offload: - logger.info( - f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" - ) - hook.offload() - - if hooks_to_offload: - clear_device_cache() - module.to(self.execution_device) - return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) - - -class UserCustomOffloadHook: - """ - A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of - the hook or remove it entirely. 
- """ - - def __init__(self, model_id, model, hook): - self.model_id = model_id - self.model = model - self.hook = hook - - def offload(self): - self.hook.init_hook(self.model) - - def attach(self): - add_hook_to_module(self.model, self.hook) - self.hook.model_id = self.model_id - - def remove(self): - remove_hook_from_module(self.model) - self.hook.model_id = None - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - self.hook.add_other_hook(hook) - - -def custom_offload_with_hook( - model_id: str, - model: torch.nn.Module, - execution_device: Union[str, int, torch.device] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, -): - hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) - user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) - user_hook.attach() - return user_hook - - -class AutoOffloadStrategy: - """ - Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on - the available memory on the device. - """ - - def __init__(self, memory_reserve_margin="3GB"): - self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) - - def __call__(self, hooks, model_id, model, execution_device): - if len(hooks) == 0: - return [] - - current_module_size = get_memory_footprint(model) - - mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] - mem_on_device = mem_on_device - self.memory_reserve_margin - if current_module_size < mem_on_device: - return [] - - min_memory_offload = current_module_size - mem_on_device - logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") - - # exlucde models that's not currently loaded on the device - module_sizes = dict( - sorted( - {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), - key=lambda x: x[1], - reverse=True, - ) - ) - - def search_best_candidate(module_sizes, min_memory_offload): - """ - search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a - minimum memory offload size. 
the combination of models should add up to the smallest modulesize that is - larger than `min_memory_offload` - """ - model_ids = list(module_sizes.keys()) - best_candidate = None - best_size = float("inf") - for r in range(1, len(model_ids) + 1): - for candidate_model_ids in combinations(model_ids, r): - candidate_size = sum( - module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids - ) - if candidate_size < min_memory_offload: - continue - else: - if best_candidate is None or candidate_size < best_size: - best_candidate = candidate_model_ids - best_size = candidate_size - - return best_candidate - - best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) - - if best_offload_model_ids is None: - # if no combination is found, meaning that we cannot meet the memory requirement, offload all models - logger.warning("no combination of models to offload to cpu is found, offloading all models") - hooks_to_offload = hooks - else: - hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] - - return hooks_to_offload - - - -from .modular_pipeline_utils import ComponentSpec -import uuid -class ComponentsManager: - def __init__(self): - self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added - self.collections = OrderedDict() # collection_name -> set of component_names - self.model_hooks = None - self._auto_offload_enabled = False - - - def _get_by_collection(self, collection: str): - """ - Select components by collection name. - """ - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components - - - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components - - - def add(self, name, component, collection: Optional[str] = None): - - for comp_id, comp in self.components.items(): - if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id - - component_id = f"{name}_{uuid.uuid4()}" - - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) - if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) - logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." 
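
The candidate search above picks the cheapest set of already-loaded models whose combined size still covers the shortfall. The same selection rule in isolation, with made-up sizes in GB purely for illustration:

from itertools import combinations

module_sizes = {"unet": 5.1, "text_encoder_2": 2.8, "text_encoder": 0.7, "vae": 0.3}
min_memory_offload = 3.0  # GB that needs to be freed on the device

best_candidate, best_size = None, float("inf")
for r in range(1, len(module_sizes) + 1):
    for candidate in combinations(module_sizes, r):
        size = sum(module_sizes[name] for name in candidate)
        if size >= min_memory_offload and size < best_size:
            best_candidate, best_size = candidate, size

print(best_candidate, round(best_size, 1))  # ('text_encoder_2', 'vae') 3.1
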
- ) - - - # add component to components manager - self.components[component_id] = component - self.added_time[component_id] = time.time() - if collection: - if collection not in self.collections: - self.collections[collection] = set() - self.collections[collection].add(component_id) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") - return component_id - - - def remove(self, name: Union[str, List[str]]): - - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") - return - - self.components.pop(name) - self.added_time.pop(name) - - for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False): - """ - Select components by name with simple pattern matching. - - Args: - names: Component name(s) or pattern(s) - Patterns: - - "unet" : match any component with base name "unet" (e.g., unet_123abc) - - "!unet" : everything except components with base name "unet" - - "unet*" : anything with base name starting with "unet" - - "!unet*" : anything with base name NOT starting with "unet" - - "*unet*" : anything with base name containing "unet" - - "!*unet*" : anything with base name NOT containing "unet" - - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" - - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" - - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" - collection: Optional collection to filter by - load_id: Optional load_id to filter by - as_name_component_tuples: If True, returns a list of (name, component) tuples using base names - instead of a dictionary with component IDs as keys - - Returns: - Dictionary mapping component IDs to components, - or list of (base_name, component) tuples if as_name_component_tuples=True - """ - - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) - - # Helper to extract base name from component_id - def get_base_name(component_id): - parts = component_id.split('_') - # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return component_id - - if names is None: - if as_name_component_tuples: - return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] - else: - return components - - # Create mapping from component_id to base_name for all components - base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - - def matches_pattern(component_id, pattern, exact_match=False): - """ - Helper function to check if a component matches a pattern based on its base name. 
- - Args: - component_id: The component ID to check - pattern: The pattern to match against - exact_match: If True, only exact matches to base_name are considered - """ - base_name = base_names[component_id] - - # Exact match with base name - if exact_match: - return pattern == base_name - - # Prefix match (ends with *) - elif pattern.endswith('*'): - prefix = pattern[:-1] - return base_name.startswith(prefix) - - # Contains match (starts with *) - elif pattern.startswith('*'): - search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] - return search in base_name - - # Exact match (no wildcards) - else: - return pattern == base_name - - if isinstance(names, str): - # Check if this is a "not" pattern - is_not_pattern = names.startswith('!') - if is_not_pattern: - names = names[1:] # Remove the ! prefix - - # Handle OR patterns (containing |) - if '|' in names: - terms = names.split('|') - matches = {} - - for comp_id, comp in components.items(): - # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - - # Check if any of the terms match this component - should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - - # Flip the decision if this is a NOT pattern - if is_not_pattern: - should_include = not should_include - - if should_include: - matches[comp_id] = comp - - log_msg = "NOT " if is_not_pattern else "" - match_type = "exactly matching" if exact_match else "matching any of patterns" - logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - - # Try exact match with a base name - elif any(names == base_name for base_name in base_names.values()): - # Find all components with this base name - matches = { - comp_id: comp for comp_id, comp in components.items() - if (base_names[comp_id] == names) != is_not_pattern - } - - if is_not_pattern: - logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - - # Prefix match (ends with *) - elif names.endswith('*'): - prefix = names[:-1] - matches = { - comp_id: comp for comp_id, comp in components.items() - if base_names[comp_id].startswith(prefix) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") - else: - logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - - # Contains match (starts with *) - elif names.startswith('*'): - search = names[1:-1] if names.endswith('*') else names[1:] - matches = { - comp_id: comp for comp_id, comp in components.items() - if (search in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - - # Substring match (no wildcards, but not an exact component name) - elif any(names in base_name for base_name in base_names.values()): - matches = { - comp_id: comp for comp_id, comp in components.items() - if (names in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - - else: - raise ValueError(f"Component or pattern 
'{names}' not found in ComponentsManager") - - if not matches: - raise ValueError(f"No components found matching pattern '{names}'") - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] - else: - return matches - - elif isinstance(names, list): - results = {} - for name in names: - result = self.get(name, collection, load_id, as_name_component_tuples=False) - results.update(result) - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in results.items()] - else: - return results - - else: - raise ValueError(f"Invalid type for names: {type(names)}") - - def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): - remove_hook_from_module(component, recurse=True) - - self.disable_auto_cpu_offload() - offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) - device = torch.device(device) - if device.index is None: - device = torch.device(f"{device.type}:{0}") - all_hooks = [] - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module): - hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) - all_hooks.append(hook) - - for hook in all_hooks: - other_hooks = [h for h in all_hooks if h is not hook] - for other_hook in other_hooks: - if other_hook.hook.execution_device == hook.hook.execution_device: - hook.add_other_hook(other_hook) - - self.model_hooks = all_hooks - self._auto_offload_enabled = True - self._auto_offload_device = device - - def disable_auto_cpu_offload(self): - if self.model_hooks is None: - self._auto_offload_enabled = False - return - - for hook in self.model_hooks: - hook.offload() - hook.remove() - if self.model_hooks: - clear_device_cache() - self.model_hooks = None - self._auto_offload_enabled = False - - # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: - """Get comprehensive information about a component. - - Args: - name: Name of the component to get info for - fields: Optional field(s) to return. Can be a string for single field or list of fields. - If None, returns all fields. - - Returns: - Dictionary containing requested component metadata. - If fields is specified, returns only those fields. - If a single field is requested as string, returns just that field's value. 
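
A sketch of how registration and lookup fit together; `unet`, `vae` and `text_encoder` stand in for already-loaded components and the collection name is arbitrary:

cm = ComponentsManager()

# add() generates an id of the form "<name>_<uuid>" and returns it
unet_id = cm.add("unet", unet)
cm.add("vae", vae, collection="sdxl")
cm.add("text_encoder", text_encoder, collection="sdxl")

cm.get("unet")             # components whose base name is exactly "unet"
cm.get("text_encoder*")    # base names starting with "text_encoder"
cm.get("!unet")            # everything except "unet"
cm.get("unet|vae")         # base name exactly "unet" or "vae"
cm.get(collection="sdxl")  # filter by collection instead of by name

# once everything is registered, one call puts all torch modules under the
# CPU-offload hooks defined earlier in this file
cm.enable_auto_cpu_offload(device="cuda", memory_reserve_margin="3GB")
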
- """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") - - component = self.components[name] - - # Build complete info dict first - info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), - } - - # Additional info for torch.nn.Module components - if isinstance(component, torch.nn.Module): - # Check for hook information - has_hook = hasattr(component, "_hf_hook") - execution_device = None - if has_hook and hasattr(component._hf_hook, "execution_device"): - execution_device = component._hf_hook.execution_device - - info.update({ - "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), - "adapters": None, # Default to None - "has_hook": has_hook, - "execution_device": execution_device, - }) - - # Get adapters if applicable - if hasattr(component, "peft_config"): - info["adapters"] = list(component.peft_config.keys()) - - # Check for IP-Adapter scales - if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): - processors = copy.deepcopy(component.attn_processors) - # First check if any processor is an IP-Adapter - processor_types = [v.__class__.__name__ for v in processors.values()] - if any("IPAdapter" in ptype for ptype in processor_types): - # Then get scales only from IP-Adapter processors - scales = { - k: v.scale - for k, v in processors.items() - if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ - } - if scales: - info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) - - # If fields specified, filter info - if fields is not None: - if isinstance(fields, str): - # Single field requested, return just that value - return {fields: info.get(fields)} - else: - # List of fields requested, return dict with just those fields - return {k: v for k, v in info.items() if k in fields} - - return info - - def __repr__(self): - # Helper to get simple name without UUID - def get_simple_name(name): - # Extract the base name by splitting on underscore and taking first part - # This assumes names are in format "name_uuid" - parts = name.split('_') - # If we have at least 2 parts and the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return name - - # Extract load_id if available - def get_load_id(component): - if hasattr(component, "_diffusers_load_id"): - return component._diffusers_load_id - return "N/A" - - # Format device info compactly - def format_device(component, info): - if not info["has_hook"]: - return str(getattr(component, 'device', 'N/A')) - else: - device = str(getattr(component, 'device', 'N/A')) - exec_device = str(info['execution_device'] or 'N/A') - return f"{device}({exec_device})" - - # Get all simple names to calculate width - simple_names = [get_simple_name(id) for id in self.components.keys()] - - # Get max length of load_ids for models - load_ids = [ - get_load_id(component) - for component in self.components.values() - if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") - ] - max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] - - col_widths = { - "name": max(15, max(len(name) for name in 
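
Continuing the sketch above, per-component metadata is keyed by the full component id returned by `add()`; `fields` narrows the result to a single-key or filtered dict:

info = cm.get_model_info(unet_id)
info["class_name"], info["size_gb"], info["execution_device"]

cm.get_model_info(unet_id, fields="size_gb")                   # {"size_gb": ...}
cm.get_model_info(unet_id, fields=["collection", "adapters"])  # only those keys
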
simple_names)), - "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 15, # Reduced since using more compact format - "dtype": 15, - "size": 10, - "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) - } - - # Create the header lines - sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - - output = "Components:\n" + sep_line - - # Separate components into models and others - models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} - others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} - - # Models section - if models: - output += "Models:\n" + dash_line - # Column headers - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " - output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" - output += dash_line - - # Model entries - for name, component in models.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - device_str = format_device(component, info) - dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - load_id = get_load_id(component) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" - output += dash_line - - # Other components section - if others: - if models: # Add extra newline if we had models section - output += "\n" - output += "Other Components:\n" + dash_line - # Column headers for other components - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" - output += dash_line - - # Other component entries - for name, component in others.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" - output += dash_line - - # Add additional component info - output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" - for name in self.components: - info = self.get_model_info(name) - if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - simple_name = get_simple_name(name) - output += f"\n{simple_name}:\n" - if info.get("adapters") is not None: - output += f" Adapters: {info['adapters']}\n" - if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" - output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" - - return output - - def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): - """ - Load components from a pretrained model and add them to the manager. - - Args: - pretrained_model_name_or_path (str): The path or identifier of the pretrained model - prefix (str, optional): Prefix to add to all component names loaded from this model. 
- If provided, components will be named as "{prefix}_{component_name}" - **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() - """ - subfolder = kwargs.pop("subfolder", None) - # YiYi TODO: extend AutoModel to support non-diffusers models - if subfolder: - from ..models import AutoModel - component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) - component_name = f"{prefix}_{subfolder}" if prefix else subfolder - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - else: - from ..pipelines.pipeline_utils import DiffusionPipeline - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: - """ - Get a single component by name. Raises an error if multiple components match or none are found. - - Args: - name: Component name or pattern - collection: Optional collection to filter by - load_id: Optional load_id to filter by - - Returns: - A single component - - Raises: - ValueError: If no components match or multiple components match - """ - results = self.get(name, collection, load_id) - - if not results: - raise ValueError(f"No components found matching '{name}'") - - if len(results) > 1: - raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - - return next(iter(results.values())) - -def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: - """Summarizes a dictionary by finding common prefixes that share the same value. 
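
A sketch of filling a manager directly from the Hub; the repo ids, prefixes and dtype are only examples:

import torch

cm = ComponentsManager()

# load every component of a pipeline repo, prefixing names to avoid clashes
cm.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", prefix="sdxl", torch_dtype=torch.float16)

# or load a single model from a subfolder of a repo
cm.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", prefix="refiner", subfolder="unet")

# get_one() raises if the name matches zero or several components
unet = cm.get_one("sdxl_unet")
vae = cm.get_one("sdxl_vae")
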
- - For a dictionary with dot-separated keys like: - { - 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], - 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], - 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], - } - - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: - { - 'down_blocks': [0.6], - 'up_blocks': [0.3] - } - """ - # First group by values - convert lists to tuples to make them hashable - value_to_keys = {} - for key, value in d.items(): - value_tuple = tuple(value) if isinstance(value, list) else value - if value_tuple not in value_to_keys: - value_to_keys[value_tuple] = [] - value_to_keys[value_tuple].append(key) - - def find_common_prefix(keys: List[str]) -> str: - """Find the shortest common prefix among a list of dot-separated keys.""" - if not keys: - return "" - if len(keys) == 1: - return keys[0] - - # Split all keys into parts - key_parts = [k.split('.') for k in keys] - - # Find how many initial parts are common - common_length = 0 - for parts in zip(*key_parts): - if len(set(parts)) == 1: # All parts at this position are the same - common_length += 1 - else: - break - - if common_length == 0: - return "" - - # Return the common prefix - return '.'.join(key_parts[0][:common_length]) - - # Create summary by finding common prefixes for each value group - summary = {} - for value_tuple, keys in value_to_keys.items(): - prefix = find_common_prefix(keys) - if prefix: # Only add if we found a common prefix - # Convert tuple back to list if it was originally a list - value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple - summary[prefix] = value - else: - summary[""] = value # Use empty string if no common prefix - - return summary From de6ab6b49d17b9e735638f9f21df9b173fd2d5b0 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:07:58 +0200 Subject: [PATCH 35/38] fix import in block mapping --- .../stable_diffusion_xl/modular_pipeline_block_mappings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py index c739a24e9759..6d909ab5a4a0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -41,11 +41,11 @@ StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLIPAdapterStep ) -from .after_denoise import ( +from .decoders import ( StableDiffusionXLDecodeStep, - StableDiffusionXLInpaintDecodeStep + StableDiffusionXLInpaintDecodeStep, + StableDiffusionXLAutoDecodeStep ) -from .after_denoise import StableDiffusionXLAutoDecodeStep # YiYi notes: comment out for now, work on this later From eb9415031a54b6aba5a44a52ead90197502f806f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:08:28 +0200 Subject: [PATCH 36/38] add a to-do for modular loader --- src/diffusers/modular_pipelines/modular_pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 36273da11f5a..ef725c32f4f9 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1638,9 +1638,10 @@ def __repr__(self): # YiYi 
TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config +# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 4. add validator for methods where we accpet kwargs to be passed to from_pretrained() class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. From 1b89ac144c6eba7d66ca34924c83a6323944ccac Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:10:06 +0200 Subject: [PATCH 37/38] prepare_latents_img2img pipeline method -> function, maybe do the same for others? --- .../stable_diffusion_xl/before_denoise.py | 168 +++++++++--------- 1 file changed, 83 insertions(+), 85 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 8f083f1870e7..07f096249c0d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -127,6 +127,86 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): + + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + class StableDiffusionXLInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @@ -751,89 +831,6 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - @staticmethod - def prepare_latents_img2img( - components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and 
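
Because the helper is now a plain function, it can be exercised outside a pipeline block. A sketch assuming `vae` (an AutoencoderKL), `scheduler` (e.g. an EulerDiscreteScheduler with `set_timesteps` already called) and a preprocessed `image` tensor are available and already on the execution device:

latent_timestep = scheduler.timesteps[:1].repeat(2)  # illustrative; normally derived from `strength`

latents = prepare_latents_img2img(
    vae,
    scheduler,
    image,                      # (1, 3, H, W) tensor, e.g. from VaeImageProcessor.preprocess
    latent_timestep,
    batch_size=2,
    num_images_per_prompt=1,
    dtype=torch.float16,
    device=torch.device("cuda"),
    generator=torch.Generator("cuda").manual_seed(0),
    add_noise=True,
)
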
len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -842,8 +839,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt block_state.device = components._execution_device block_state.add_noise = True if block_state.denoising_start is None else False if block_state.latents is None: - block_state.latents = self.prepare_latents_img2img( - components, + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, block_state.image_latents, block_state.latent_timestep, block_state.batch_size, From d136ae36c87b66cb6e53c30098d09cd307641588 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 May 2025 18:11:05 +0200 Subject: [PATCH 38/38] update input for loop blocks, do not need to include intermediate --- .../stable_diffusion_xl/denoise.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index b29920764acb..bc567a6b034f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -68,18 +68,11 @@ def intermediates_inputs(self) -> List[str]: ), ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - return components, block_state # loop step (1): prepare latent input for denoiser (with inpainting) @@ -120,9 +113,6 @@ def intermediates_inputs(self) -> List[str]: ), ] - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] @staticmethod def check_inputs(components, block_state): @@ -187,12 +177,6 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), InputParam( "num_inference_steps", required=True, @@ -319,12 +303,6 @@ def intermediates_inputs(self) -> List[str]: type_hint=List[float], description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." 
- ), InputParam( "timestep_cond", type_hint=Optional[torch.Tensor], @@ -492,12 +470,6 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam("generator"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), ] @property