From 18de3adad135a83d7e7bb675b60eaaddfe2530c0 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Thu, 30 Jan 2025 03:00:59 +0000 Subject: [PATCH 01/19] run control-lora on diffusers --- src/diffusers/loaders/single_file_utils.py | 144 ++ .../models/controlnets/controlnet_lora.py | 593 ++++++++ .../pipelines/control_lora/__init__.py | 66 + .../pipeline_control_lora_sd_xl.py | 1272 +++++++++++++++++ 4 files changed, 2075 insertions(+) create mode 100644 src/diffusers/models/controlnets/controlnet_lora.py create mode 100644 src/diffusers/pipelines/control_lora/__init__.py create mode 100644 src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 731b7b87f625..ba00911b40ef 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1311,6 +1311,150 @@ def convert_controlnet_checkpoint( return new_checkpoint +def convert_control_lora_checkpoint( + checkpoint, + config, + **kwargs, +): + # Return checkpoint if it's already been converted + if "time_embedding.linear_1.weight" in checkpoint: + return checkpoint + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) + else: + controlnet_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + for k, v in controlnet_state_dict.items(): + if "time_embed.0" in k: + new_checkpoint[k.replace("time_embed.0", "time_embedding.linear_1")] = v + elif "time_embed.2" in k: + new_checkpoint[k.replace("time_embed.2", "time_embedding.linear_2")] = v + elif "input_blocks.0.0" in k: + new_checkpoint[k.replace("input_blocks.0.0", "conv_in")] = v + elif "label_emb.0.0" in k: + new_checkpoint[k.replace("label_emb.0.0", "add_embedding.linear_1")] = v + elif "label_emb.0.2" in k: + new_checkpoint[k.replace("label_emb.0.2", "add_embedding.linear_2")] = v + elif "input_blocks.3.0.op" in k: + new_checkpoint[k.replace("input_blocks.3.0.op", "down_blocks.0.downsamplers.0.conv")] = v + elif "input_blocks.6.0.op" in k: + new_checkpoint[k.replace("input_blocks.6.0.op", "down_blocks.1.downsamplers.0.conv")] = v + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": 
f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + return new_checkpoint + + def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys diff --git a/src/diffusers/models/controlnets/controlnet_lora.py b/src/diffusers/models/controlnets/controlnet_lora.py new file mode 100644 index 000000000000..65f94cf65a56 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_lora.py @@ -0,0 +1,593 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import BaseOutput, logging +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from ..unets.unet_2d_condition import UNet2DConditionModel +from .controlnet import ControlNetConditioningEmbedding, ControlNetModel, zero_module + +from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer +from ...loaders.single_file_utils import load_single_file_checkpoint, convert_controlnet_checkpoint, convert_control_lora_checkpoint + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetLoRAOutput(BaseOutput): + """ + The output of [`ControlNetLoRAModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetLoRAModel(ControlNetModel): + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + for name, module in self.named_modules(): + *parent_module_path, attr_name = name.split(".") + parent_module = self + for path_part in parent_module_path: + parent_module = getattr(parent_module, path_part) + + if isinstance(module, nn.Linear): + module = LinearWithLoRA( + in_features=module.in_features, + out_features=module.out_features, + bias=False if module.bias is None else True, + ) + setattr(parent_module, attr_name, module) + elif isinstance(module, nn.Conv2d): + module = Conv2dWithLoRA( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + bias=False if module.bias is None else True + ) + setattr(parent_module, attr_name, module) + + @classmethod + def from_unet_and_single_file( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + pretrained_model_link_or_path_or_dict: Optional[str] = None, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. 
+ """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controllora = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + if load_weights_from_unet: + controllora.conv_in.load_state_dict(unet.conv_in.state_dict()) + controllora.time_proj.load_state_dict(unet.time_proj.state_dict()) + controllora.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controllora.class_embedding: + controllora.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + if hasattr(controllora, "add_embedding"): + controllora.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controllora.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controllora.mid_block.load_state_dict(unet.mid_block.state_dict()) + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + checkpoint = pretrained_model_link_or_path_or_dict + elif isinstance(pretrained_model_link_or_path_or_dict, str): + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path_or_dict, + ) + else: + raise ValueError + + config = ControlNetModel.load_config("xinsir/controlnet-canny-sdxl-1.0") + checkpoint = convert_control_lora_checkpoint(checkpoint, config) + + for name, param in checkpoint.items(): + *parent_module_path, attr_name = name.split(".") + parent_module = controllora + for path_part in parent_module_path: + parent_module = getattr(parent_module, path_part) + + if getattr(parent_module, attr_name, None) is None: + setattr(parent_module, attr_name, 
param.to(controllora.device)) + missing, unexpected = controllora.load_state_dict(checkpoint, strict=False) + # print("missing: ", missing) + # print("unexpected: ", unexpected) + + return controllora + + +class LinearWithLoRA(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + self.up = None + self.down = None + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.up is not None and self.down is not None: + weight = self.weight + torch.mm(self.up.to(self.weight.device), self.down.to(self.weight.device)) + return F.linear(input, weight, self.bias) + else: + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, up={self.up is not None}, down={self.down is not None}" + + +class Conv2dWithLoRA(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[str, Union[int, Tuple[int, int]]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.padding_mode = padding_mode + + self.weight = nn.Parameter( + torch.empty( + (out_channels, in_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.up = None + self.down = None + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.up is not None and self.down is not None: + weight = self.weight + 
torch.mm(self.up.flatten(1).to(self.weight.device), self.down.flatten(1).to(self.weight.device)).reshape(self.weight.shape)
+            return F.conv2d(
+                input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+            )
+        else:
+            return F.conv2d(
+                input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+            )
+
+    def extra_repr(self):
+        s = (
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
+            ", stride={stride}"
+        )
+        if self.padding != (0,) * len(self.padding):
+            s += ", padding={padding}"
+        if self.bias is None:
+            s += ", bias=False"
+        if self.up is not None:
+            s += ", up=True"
+        if self.down is not None:
+            s += ", down=True"
+        return s.format(**self.__dict__)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        if not hasattr(self, "padding_mode"):
+            self.padding_mode = "zeros"
+
+
+if __name__ == "__main__":
+    pass
diff --git a/src/diffusers/pipelines/control_lora/__init__.py b/src/diffusers/pipelines/control_lora/__init__.py
new file mode 100644
index 000000000000..caeac8ac61ec
--- /dev/null
+++ b/src/diffusers/pipelines/control_lora/__init__.py
@@ -0,0 +1,66 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_flax_available,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_control_lora_sd_xl"] = ["StableDiffusionXLControlLoRAPipeline"]
+try:
+    if not (is_transformers_available() and is_flax_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_flax_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+    pass  # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_control_lora_sd_xl import StableDiffusionXLControlLoRAPipeline
+
+    try:
+        if not (is_transformers_available() and is_flax_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
+    else:
+        pass  # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py b/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py
new file mode 100644
index 000000000000..033933029907
--- /dev/null
+++ b/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py
@@ -0,0 +1,1272 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + +from ...models.controlnets.controlnet_lora import ControlNetLoRAModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... 
) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlLoRAPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetLoRAModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, 
num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` and `controlnet_conditioning_scale` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetLoRAModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetLoRAModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + # @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: float = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not 
isinstance(control_guidance_end, list): + control_guidance_start, control_guidance_end = ( + [control_guidance_start], + [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = controlnet.config.global_pool_conditions + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetLoRAModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetLoRAModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # control-lora inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
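+                    # In guess mode only the conditional half of the CFG batch is run through the
+                    # ControlNet: the un-doubled latents are re-scaled for this timestep and the
+                    # prompt / added-cond embeddings are chunked down to their conditional halves.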
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + image = callback_outputs.pop("image", image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the 
VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + +if __name__ == "__main__": + from diffusers import ( + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, + ) + import torch + + pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" + controlnet_id = "xinsir/controlnet-canny-sdxl-1.0" + lora_id = "stabilityai/control-lora" + lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" + + + unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.float16).to("cuda") + controlnet = ControlNetLoRAModel.from_unet_and_single_file(unet, pretrained_model_link_or_path_or_dict="https://huggingface.co/stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors").to("cuda").to(torch.float16) + + from diffusers import AutoencoderKL + from diffusers.utils import load_image, make_image_grid + from PIL import Image + import numpy as np + import cv2 + + prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + negative_prompt = "low quality, bad quality, sketches" + + image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") + + controlnet_conditioning_scale = 1.0 # recommended for good generalization + + vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + pipe = StableDiffusionXLControlLoRAPipeline.from_pretrained( + pipe_id, + unet=unet, + controlnet=controlnet, + vae=vae, + torch_dtype=torch.float16, + safety_checker=None, + ).to("cuda") + + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + image = Image.fromarray(image) + + images = pipe( + prompt, negative_prompt=negative_prompt, image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale, + num_images_per_prompt=4 + ).images + + final_image = [image] + images + grid = make_image_grid(final_image, 1, 5) + 
grid.save(f"hf-logo1.png") From e9d91e156d5d6d19dbf3c3a2d42cdb1c88659eac Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 1 Feb 2025 01:29:19 +0000 Subject: [PATCH 02/19] cannot load lora adapter --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/controlnet.py | 341 ++++++++++++++++++ .../loaders/lora_conversion_utils.py | 153 ++++++++ .../models/controlnets/controlnet.py | 3 +- .../pipelines/control_lora/control_lora.py | 58 +++ 5 files changed, 556 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/loaders/controlnet.py create mode 100644 src/diffusers/pipelines/control_lora/control_lora.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 2db8b53db498..f79c33af03ec 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -55,6 +55,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["controlnet"] = ["ControlNetLoadersMixin"] _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] @@ -87,6 +88,7 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .controlnet import ControlNetLoadersMixin from .transformer_flux import FluxTransformer2DLoadersMixin from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py new file mode 100644 index 000000000000..128ef6670796 --- /dev/null +++ b/src/diffusers/loaders/controlnet.py @@ -0,0 +1,341 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from contextlib import nullcontext +from typing import Dict, List, Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from ..loaders.single_file_utils import DIFFUSERS_TO_LDM_MAPPING +from ..models.embeddings import ( + ImageProjection, + MultiIPAdapterImageProjection, +) +from ..models.modeling_utils import load_model_dict_into_meta +from ..utils import ( + USE_PEFT_BACKEND, + deprecate, + get_submodule_by_name, + is_accelerate_available, + is_peft_available, + is_peft_version, + is_torch_version, + is_transformers_available, + is_transformers_version, + logging, +) +from .lora_base import ( # noqa + LORA_WEIGHT_NAME, + LORA_WEIGHT_NAME_SAFE, + LoraBaseMixin, + _fetch_state_dict, + _load_lora_into_text_encoder, +) +from .lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, + _convert_hunyuan_video_lora_to_diffusers, + _convert_kohya_flux_lora_to_diffusers, + _convert_non_diffusers_lora_to_diffusers, + _convert_stabilityai_control_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) + + +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + +logger = logging.get_logger(__name__) + +CONTROLNET_NAME = "controlnet" + + +class ControlNetLoadersMixin: + """ + Load layers into a [`ControlNetModel`]. + """ + + _lora_loadable_modules = ["controlnet"] + controlnet_name = CONTROLNET_NAME + _control_lora_supported_norm_keys = ["norm1", "norm2", "norm3"] + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + return_alphas: bool = False, + **kwargs): + r""" + """ + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_stabilityai = "lora_controlnet" in state_dict and "input_blocks.11.0.in_layers.0.weight" not in state_dict + if is_stabilityai: + state_dict = _convert_stabilityai_control_lora_to_diffusers(state_dict) + return (state_dict, None) if return_alphas else state_dict + + raise ValueError + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if 
low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs + ) + + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Control LoRAs also have norm keys + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) + + if not (has_lora_keys or has_norm_keys): + raise ValueError("Invalid LoRA checkpoint.") + + controlnet_lora_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) if "lora" in k + } + controlnet_norm_state_dict = { + k: state_dict.pop(k) + for k in list(state_dict.keys()) + if any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + } + controlnet_others_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) + } + + controlnet = self + + if len(controlnet_lora_state_dict) > 0: + self.load_lora_into_controlnet( + controlnet_lora_state_dict, + network_alphas=network_alphas, + controlnet=controlnet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if len(controlnet_norm_state_dict) > 0: + self._load_norm_into_controlnet( + controlnet_norm_state_dict, + controlnet=controlnet, + discard_original_layers=False, + ) + + if len(controlnet_others_state_dict) > 0: + self._load_others_into_controlnet( + controlnet_others_state_dict, + controlnet=controlnet, + discard_original_layers=False, + ) + + @classmethod + def load_lora_into_controlnet( + cls, state_dict, network_alphas, controlnet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to controlnet. + logger.info(f"Loading {cls}.") + controlnet.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + def _load_norm_into_controlnet( + cls, + state_dict, + controlnet, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.controlnet_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Find invalid keys + controlnet_state_dict = controlnet.state_dict() + controlnet_keys = set(controlnet_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - controlnet_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the controlnet. The following keys will be ignored:\n{extra_keys}." 
+ ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = controlnet_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the controlnet " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the controlnet and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." + ) + + # We can't load with strict=True because the current state_dict does not contain all the controlnet keys + incompatible_keys = controlnet.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the controlnet." + ) + + return overwritten_layers_state_dict + + @classmethod + def _load_others_into_controlnet( + cls, + state_dict, + controlnet, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.controlnet_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Find invalid keys + controlnet_state_dict = controlnet.state_dict() + controlnet_keys = set(controlnet_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - controlnet_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the controlnet. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = controlnet_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the controlnet and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." 
+ ) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + incompatible_keys = controlnet.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." + ) + + return overwritten_layers_state_dict + + def fuse_lora( + self, + components: List[str] = ["controlnet"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["controlnet"], **kwargs): + super().unfuse_lora(components=components) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index e064aeba43b6..e297d5063026 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -17,6 +17,12 @@ import torch from ..utils import is_peft_version, logging +from .single_file_utils import ( + DIFFUSERS_TO_LDM_MAPPING, + LDM_CONTROLNET_KEY, + update_unet_resnet_ldm_to_diffusers, + update_unet_attention_ldm_to_diffusers +) logger = logging.get_logger(__name__) @@ -1148,3 +1154,150 @@ def remap_single_transformer_blocks_(key, state_dict): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict + +def _convert_stabilityai_control_lora_to_diffusers(checkpoint): + # Return checkpoint if it's already been converted + if "time_embedding.linear_1.weight" in checkpoint: + return checkpoint + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. 
https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) + else: + controlnet_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + for k, v in controlnet_state_dict.items(): + if "time_embed.0" in k: + new_checkpoint[k.replace("time_embed.0", "time_embedding.linear_1")] = v + elif "time_embed.2" in k: + new_checkpoint[k.replace("time_embed.2", "time_embedding.linear_2")] = v + elif "input_blocks.0.0" in k: + new_checkpoint[k.replace("input_blocks.0.0", "conv_in")] = v + elif "label_emb.0.0" in k: + new_checkpoint[k.replace("label_emb.0.0", "add_embedding.linear_1")] = v + elif "label_emb.0.2" in k: + new_checkpoint[k.replace("label_emb.0.2", "add_embedding.linear_2")] = v + elif "input_blocks.3.0.op" in k: + new_checkpoint[k.replace("input_blocks.3.0.op", "down_blocks.0.downsamplers.0.conv")] = v + elif "input_blocks.6.0.op" in k: + new_checkpoint[k.replace("input_blocks.6.0.op", "down_blocks.1.downsamplers.0.conv")] = v + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (2 + 1) + layer_in_block_id = (i - 1) % (2 + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # 
Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + new_new_checkpoint = {} + for k, v in controlnet_state_dict.items(): + if ".down" in k: + new_new_checkpoint[k.replace(".down", "lora_A.weight")] = v + elif ".up" in k: + new_new_checkpoint[k.replace(".up", "lora_B.weight")] = v + new_new_checkpoint[k] = v + + return new_new_checkpoint diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 1453aaf4362c..25f9f8a5ce7e 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,6 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, ControlNetLoadersMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention_processor import ( @@ -108,7 +109,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, ControlNetLoadersMixin, PeftAdapterMixin): """ A ControlNet model. 
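For reference, the conversion above ends by renaming the low-rank factors that the Stability AI checkpoints store as `.down` / `.up` into PEFT-style `lora_A` / `lora_B` weights, while the fully-stored tensors (norms, zero convs, hint blocks) are kept as-is. A minimal, standalone sketch of that renaming on a toy state dict (not part of the patch; shapes and key names are illustrative only):

import torch

# Toy stand-in for a control-LoRA checkpoint: ".down"/".up" hold the low-rank
# factors, everything else is shipped as a full tensor.
toy_state_dict = {
    "input_blocks.1.1.transformer_blocks.0.attn1.to_q.down": torch.zeros(16, 320),
    "input_blocks.1.1.transformer_blocks.0.attn1.to_q.up": torch.zeros(320, 16),
    "input_blocks.1.1.transformer_blocks.0.norm1.weight": torch.zeros(320),
}

lora_factors = {k: v for k, v in toy_state_dict.items() if k.endswith((".down", ".up"))}
full_tensors = {k: v for k, v in toy_state_dict.items() if k not in lora_factors}

peft_style = {}
for k, v in lora_factors.items():
    if k.endswith(".down"):
        peft_style[k[: -len(".down")] + ".lora_A.weight"] = v  # down projection -> lora_A
    else:
        peft_style[k[: -len(".up")] + ".lora_B.weight"] = v    # up projection -> lora_B

print(sorted(peft_style))    # ...to_q.lora_A.weight / ...to_q.lora_B.weight
print(sorted(full_tensors))  # norm weights etc. stay untouched

Note the leading dot in ".lora_A.weight" / ".lora_B.weight": dropping it would glue the suffix directly onto the module name and produce keys no loader recognizes.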
diff --git a/src/diffusers/pipelines/control_lora/control_lora.py b/src/diffusers/pipelines/control_lora/control_lora.py new file mode 100644 index 000000000000..b7e8de7da1ac --- /dev/null +++ b/src/diffusers/pipelines/control_lora/control_lora.py @@ -0,0 +1,58 @@ + + +if __name__ == "__main__": + from diffusers import ( + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, + ) + import torch + + pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" + controlnet_id = "xinsir/controlnet-canny-sdxl-1.0" + lora_id = "stabilityai/control-lora" + lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" + + + unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.float16).to("cuda") + controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.float16) + controlnet.load_lora_weights(lora_id, weight_name=lora_filename, controlnet_config=controlnet.config) + + from diffusers import AutoencoderKL + from diffusers.utils import load_image, make_image_grid + from PIL import Image + import numpy as np + import cv2 + + prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + negative_prompt = "low quality, bad quality, sketches" + + image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") + + controlnet_conditioning_scale = 1.0 # recommended for good generalization + + vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + pipe_id, + unet=unet, + controlnet=controlnet, + vae=vae, + torch_dtype=torch.float16, + safety_checker=None, + ).to("cuda") + + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + image = Image.fromarray(image) + + images = pipe( + prompt, negative_prompt=negative_prompt, image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale, + num_images_per_prompt=4 + ).images + + final_image = [image] + images + grid = make_image_grid(final_image, 1, 5) + grid.save(f"hf-logo1.png") From 9cf8ad7a73f86586c73a0b0b4cf0aeb2fd238ea5 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Tue, 4 Feb 2025 16:40:59 +0000 Subject: [PATCH 03/19] test --- src/diffusers/pipelines/control_lora/control_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/control_lora/control_lora.py b/src/diffusers/pipelines/control_lora/control_lora.py index b7e8de7da1ac..c1a7512e62d0 100644 --- a/src/diffusers/pipelines/control_lora/control_lora.py +++ b/src/diffusers/pipelines/control_lora/control_lora.py @@ -16,7 +16,7 @@ unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.float16).to("cuda") controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.float16) - controlnet.load_lora_weights(lora_id, weight_name=lora_filename, controlnet_config=controlnet.config) + controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, controlnet_config=controlnet.config) from diffusers import AutoencoderKL from diffusers.utils import load_image, make_image_grid From 2453e149d2bc6d0f94d5b7f502c6b3a53fa50edf Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Fri, 7 Feb 2025 10:24:24 +0000 Subject: [PATCH 04/19] 1 --- src/diffusers/loaders/peft.py | 10 +- .../models/controlnets/controlnet.py | 4 +- 
src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/state_dict_utils.py | 137 ++++++++++++++++++ 4 files changed, 149 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 0d26738eec62..659bb9711564 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -21,10 +21,12 @@ import safetensors import torch +from ..loaders.lora_conversion_utils import _convert_stabilityai_control_lora_to_diffusers from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -243,7 +245,13 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) + if "lora_controlnet" in state_dict: + del state_dict["lora_controlnet"] + state_dict = convert_control_lora_state_dict_to_peft(state_dict) + else: + state_dict = convert_unet_state_dict_to_peft(state_dict) + print(state_dict.keys()) + print(len(state_dict.keys())) rank = {} for key, val in state_dict.items(): diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 25f9f8a5ce7e..f3784d51d3bd 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, ControlNetLoadersMixin +from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention_processor import ( @@ -109,7 +109,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, ControlNetLoadersMixin, PeftAdapterMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): """ A ControlNet model. 
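The `lora_controlnet` branch added to `load_lora_adapter` above keys off a marker entry that these checkpoints carry. A hedged, standalone sketch of that detection step, using `safetensors` directly instead of the loader machinery (the local file path is hypothetical):

from safetensors.torch import load_file

# Hypothetical local copy of e.g. control-lora-canny-rank128.safetensors
state_dict = load_file("control-lora-canny-rank128.safetensors")

if "lora_controlnet" in state_dict:
    # marker entry only, not a real weight
    state_dict.pop("lora_controlnet")
    # the remaining keys still use ".down"/".up" plus full norm / zero-conv / hint tensors,
    # so they need the control-LoRA -> PEFT conversion before being injected as an adapter;
    # ordinary diffusers LoRAs instead take the convert_unet_state_dict_to_peft path.
print(len(state_dict), "tensors")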
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d82aded4c435..5582391378c0 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,6 +122,7 @@ convert_state_dict_to_kohya, convert_state_dict_to_peft, convert_unet_state_dict_to_peft, + convert_control_lora_state_dict_to_peft, ) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 62b114ba67e3..ed353ee6d7f9 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -50,6 +50,18 @@ class StateDictType(enum.Enum): ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", } +CONTROL_LORA_TO_DIFFUSERS = { + ".to_out.up": ".to_out.0.lora_B", + ".to_out.down": ".to_out.0.lora_A", + ".to_q.down": ".to_q.lora_A", + ".to_q.up": ".to_q.lora_B", + ".to_k.down": ".to_k.lora_A", + ".to_k.up": ".to_k.lora_B", + ".to_v.down": ".to_v.lora_A", + ".to_v.up": ".to_v.lora_B", + ".down": ".lora_A", + ".up": ".lora_B", +} DIFFUSERS_TO_PEFT = { ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", @@ -253,6 +265,131 @@ def convert_unet_state_dict_to_peft(state_dict): return convert_state_dict(state_dict, mapping) +def convert_control_lora_state_dict_to_peft(state_dict): + def _convert_controlnet_to_diffusers(state_dict): + is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict + logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})") + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + layers_per_block = 2 + + converted_state_dict = {} + # Conv in layers + for key in input_blocks[0]: + diffusers_key = key.replace("conv_in", "input_blocks.0.0") + converted_state_dict[diffusers_key] = state_dict.get(key) + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (layers_per_block + 1) + layer_in_block_id = (i - 1) % (layers_per_block + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + for key in resnets: + diffusers_key = (key.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + diffusers_key = diffusers_key.replace( + f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + if f"input_blocks.{i}.0.op.weight" in state_dict: + converted_state_dict[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + converted_state_dict[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + for key in attentions: + diffusers_key = key.replace( + f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet down blocks + for i in range(num_input_blocks): + converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight") + 
converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + for k in middle_blocks[key]: + diffusers_key = (k.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + diffusers_key = diffusers_key.replace( + f"middle_block.{k}", f"mid_block.resnets.{diffusers_key}" + ) + converted_state_dict[diffusers_key] = state_dict.get(k) + else: + for k in middle_blocks[key]: + diffusers_key = k.replace( + f"middle_block.{k}", f"mid_block.attentions.{diffusers_key}" + ) + converted_state_dict[diffusers_key] = state_dict.get(k) + + # mid block + converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight") + converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + return converted_state_dict + + state_dict = _convert_controlnet_to_diffusers(state_dict) + mapping = CONTROL_LORA_TO_DIFFUSERS + return convert_state_dict(state_dict, mapping) + + def convert_all_state_dict_to_peft(state_dict): r""" Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid From 39b3b84acce8884dcd3777acbd992709d708670a Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Fri, 7 Feb 2025 18:43:19 +0000 Subject: [PATCH 05/19] add control-lora --- src/diffusers/loaders/__init__.py | 7 +- src/diffusers/loaders/peft.py | 189 +++++++++++++++++- .../models/controlnets/controlnet.py | 4 +- src/diffusers/utils/state_dict_utils.py | 113 ++++++++--- 4 files changed, 275 insertions(+), 38 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index f79c33af03ec..3b645307cec4 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -55,7 +55,6 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] - _import_structure["controlnet"] = ["ControlNetLoadersMixin"] _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] @@ -82,17 +81,17 @@ def text_encoder_attn_modules(text_encoder): "SD3IPAdapterMixin", 
] -_import_structure["peft"] = ["PeftAdapterMixin"] +_import_structure["peft"] = ["PeftAdapterMixin", "ControlLoRAMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin - from .controlnet import ControlNetLoadersMixin from .transformer_flux import FluxTransformer2DLoadersMixin from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers + from .peft import ControlLoRAMixin if is_transformers_available(): from .ip_adapter import ( @@ -116,7 +115,7 @@ def text_encoder_attn_modules(text_encoder): from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin - from .peft import PeftAdapterMixin + from .peft import PeftAdapterMixin, ControlLoRAMixin else: import sys diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 659bb9711564..a8ee406ae42d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -245,13 +245,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: - if "lora_controlnet" in state_dict: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) - else: - state_dict = convert_unet_state_dict_to_peft(state_dict) - print(state_dict.keys()) - print(len(state_dict.keys())) + state_dict = convert_unet_state_dict_to_peft(state_dict) rank = {} for key, val in state_dict.items(): @@ -759,3 +753,184 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): # Pop also the corresponding adapter from the config if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + + +class ControlLoRAMixin(PeftAdapterMixin): + TARGET_MODULES = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", "proj_in", "proj_out", + "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "linear_1", "linear_2", "time_emb_proj"] + SAVE_MODULES = ["controlnet_cond_embedding.conv_in", "controlnet_cond_embedding.blocks.0", + "controlnet_cond_embedding.blocks.1", "controlnet_cond_embedding.blocks.2", + "controlnet_cond_embedding.blocks.3", "controlnet_cond_embedding.blocks.4", + "controlnet_cond_embedding.blocks.5", "controlnet_cond_embedding.conv_out", + "controlnet_down_blocks.0", "controlnet_down_blocks.1", "controlnet_down_blocks.2", + "controlnet_down_blocks.3", "controlnet_down_blocks.4", "controlnet_down_blocks.5", + "controlnet_down_blocks.6", "controlnet_down_blocks.7", "controlnet_down_blocks.8", + "controlnet_mid_block", "norm", "norm1", "norm2", "norm3"] + + def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft.tuners.tuners_utils import BaseTunerLayer + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + adapter_name = kwargs.pop("adapter_name", None) + network_alphas = kwargs.pop("network_alphas", None) + _pipeline = 
kwargs.pop("_pipeline", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + allow_pickle = False + + if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + if network_alphas is not None and prefix is None: + raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + + if prefix is not None: + keys = list(state_dict.keys()) + model_keys = [k for k in keys if k.startswith(f"{prefix}.")] + if len(model_keys) > 0: + state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} + + if len(state_dict) > 0: + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." + ) + + # check with first key if is not in peft format + if "lora_controlnet" in state_dict: + del state_dict["lora_controlnet"] + state_dict = convert_control_lora_state_dict_to_peft(state_dict) + + rank = {} + for key, val in state_dict.items(): + # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: + rank[key] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config_kwargs["bias"] = "all" + lora_config_kwargs["target_modules"] = self.TARGET_MODULES + lora_config_kwargs["modules_to_save"] = self.SAVE_MODULES + lora_config = LoraConfig(**lora_config_kwargs) + # adapter_name + if adapter_name is None: + adapter_name = "default" + + # =", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, + # we should also delete the `peft_config` associated to the `adapter_name`. 
+ try: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + except Exception as e: + # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. + if hasattr(self, "peft_config"): + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapters + for active_adapter in active_adapters: + if adapter_name in active_adapter: + module.delete_adapter(adapter_name) + + self.peft_config.pop(adapter_name) + logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") + raise + + warn_msg = "" + if incompatible_keys is not None: + # Check only for unexpected keys. + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index f3784d51d3bd..58d3ab65b891 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import PeftAdapterMixin, ControlLoRAMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention_processor import ( @@ -109,7 +109,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, ControlLoRAMixin): """ A ControlNet model. 
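[Illustrative note, not part of the patch] With `ControlNetModel` now inheriting from `ControlLoRAMixin`, a Stability AI control-lora checkpoint can in principle be loaded directly onto a ControlNet built from the base SDXL UNet. A minimal usage sketch follows; the Hub repository, weight file path, and the `prefix=None` argument are assumptions for illustration only and are not confirmed by this patch series:

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel

# Build a ControlNet that shares its configuration (and initial weights) with the SDXL UNet.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_unet(unet)

# Load the control-lora weights through the PEFT-backed loader added in this patch.
# The repo id and weight name below are assumed; point them at an actual
# Stability AI control-lora file (e.g. a canny rank-128 checkpoint).
controlnet.load_lora_adapter(
    "stabilityai/control-lora",
    weight_name="control-LoRAs-rank128/control-lora-canny-rank128.safetensors",
    prefix=None,
)

# The adapted ControlNet can then be used like any other ControlNet.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, controlnet=controlnet, torch_dtype=torch.float16
)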
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index ed353ee6d7f9..322b118a6517 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -51,16 +51,55 @@ class StateDictType(enum.Enum): } CONTROL_LORA_TO_DIFFUSERS = { - ".to_out.up": ".to_out.0.lora_B", - ".to_out.down": ".to_out.0.lora_A", - ".to_q.down": ".to_q.lora_A", - ".to_q.up": ".to_q.lora_B", - ".to_k.down": ".to_k.lora_A", - ".to_k.up": ".to_k.lora_B", - ".to_v.down": ".to_v.lora_A", - ".to_v.up": ".to_v.lora_B", - ".down": ".lora_A", - ".up": ".lora_B", + ".to_q.bias": ".to_q.base_layer.bias", + ".to_q.down": ".to_q.lora_A.weight", + ".to_q.up": ".to_q.lora_B.weight", + ".to_k.bias": ".to_k.base_layer.bias", + ".to_k.down": ".to_k.lora_A.weight", + ".to_k.up": ".to_k.lora_B.weight", + ".to_v.bias": ".to_v.base_layer.bias", + ".to_v.down": ".to_v.lora_A.weight", + ".to_v.up": ".to_v.lora_B.weight", + ".to_out.0.bias": ".to_out.0.base_layer.bias", + ".to_out.0.down": ".to_out.0.lora_A.weight", + ".to_out.0.up": ".to_out.0.lora_B.weight", + ".ff.net.0.proj.bias": ".ff.net.0.proj.base_layer.bias", + ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight", + ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight", + ".ff.net.2.bias": ".ff.net.2.base_layer.bias", + ".ff.net.2.down": ".ff.net.2.lora_A.weight", + ".ff.net.2.up": ".ff.net.2.lora_B.weight", + ".proj_in.bias": ".proj_in.base_layer.bias", + ".proj_in.down": ".proj_in.lora_A.weight", + ".proj_in.up": ".proj_in.lora_B.weight", + ".proj_out.bias": ".proj_out.base_layer.bias", + ".proj_out.down": ".proj_out.lora_A.weight", + ".proj_out.up": ".proj_out.lora_B.weight", + ".conv.bias": ".conv.base_layer.bias", + ".conv.down": ".conv.lora_A.weight", + ".conv.up": ".conv.lora_B.weight", + **{f".conv{i}.bias": f".conv{i}.base_layer.bias" for i in range(1, 3)}, + **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)}, + **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)}, + "conv_in.bias": "conv_in.base_layer.bias", + "conv_in.down": "conv_in.lora_A.weight", + "conv_in.up": "conv_in.lora_B.weight", + ".conv_shortcut.bias": ".conv_shortcut.base_layer.bias", + ".conv_shortcut.down": ".conv_shortcut.lora_A.weight", + ".conv_shortcut.up": ".conv_shortcut.lora_B.weight", + **{f".linear_{i}.bias": f".linear_{i}.base_layer.bias" for i in range(1, 3)}, + **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)}, + **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)}, + "time_emb_proj.bias": "time_emb_proj.base_layer.bias", + "time_emb_proj.down": "time_emb_proj.lora_A.weight", + "time_emb_proj.up": "time_emb_proj.lora_B.weight", + "controlnet_cond_embedding.conv_in.bias": "controlnet_cond_embedding.conv_in.modules_to_save.bias", + "controlnet_cond_embedding.conv_out.bias": "controlnet_cond_embedding.conv_out.modules_to_save.bias", + **{f"controlnet_cond_embedding.blocks.{i}.bias": f"controlnet_cond_embedding.blocks.{i}.modules_to_save.bias" for i in range(6)}, + **{f"controlnet_down_blocks.{i}.bias": f"controlnet_down_blocks.{i}.modules_to_save.bias" for i in range(9)}, + "controlnet_mid_block.bias": "controlnet_mid_block.modules_to_save.bias", + ".norm.bias": ".norm.modules_to_save.bias", + **{f".norm{i}.bias": f".norm{i}.modules_to_save.bias" for i in range(1, 4)}, } DIFFUSERS_TO_PEFT = { @@ -280,10 +319,29 @@ def _convert_controlnet_to_diffusers(state_dict): } layers_per_block = 2 + # op blocks + op_blocks = [key for key in state_dict if 
"0.op" in key] + converted_state_dict = {} # Conv in layers for key in input_blocks[0]: - diffusers_key = key.replace("conv_in", "input_blocks.0.0") + diffusers_key = key.replace("input_blocks.0.0", "conv_in") + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet time embedding blocks + time_embedding_blocks = [key for key in state_dict if "time_embed" in key] + for key in time_embedding_blocks: + diffusers_key = (key.replace("time_embed.0", "time_embedding.linear_1") + .replace("time_embed.2", "time_embedding.linear_2") + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet label embedding blocks + label_embedding_blocks = [key for key in state_dict if "label_emb" in key] + for key in label_embedding_blocks: + diffusers_key = (key.replace("label_emb.0.0", "add_embedding.linear_1") + .replace("label_emb.0.2", "add_embedding.linear_2") + ) converted_state_dict[diffusers_key] = state_dict.get(key) # Down blocks @@ -307,13 +365,10 @@ def _convert_controlnet_to_diffusers(state_dict): ) converted_state_dict[diffusers_key] = state_dict.get(key) - if f"input_blocks.{i}.0.op.weight" in state_dict: - converted_state_dict[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = state_dict.get( - f"input_blocks.{i}.0.op.weight" - ) - converted_state_dict[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = state_dict.get( - f"input_blocks.{i}.0.op.bias" - ) + if f"input_blocks.{i}.0.op.bias" in state_dict: + for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]: + diffusers_key = key.replace(f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv") + converted_state_dict[diffusers_key] = state_dict.get(key) attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if attentions: @@ -342,23 +397,23 @@ def _convert_controlnet_to_diffusers(state_dict): diffusers_key = max(key - 1, 0) if key % 2 == 0: for k in middle_blocks[key]: - diffusers_key = (k.replace("in_layers.0", "norm1") + diffusers_key_hf = (k.replace("in_layers.0", "norm1") .replace("in_layers.2", "conv1") .replace("out_layers.0", "norm2") .replace("out_layers.3", "conv2") .replace("emb_layers.1", "time_emb_proj") .replace("skip_connection", "conv_shortcut") ) - diffusers_key = diffusers_key.replace( - f"middle_block.{k}", f"mid_block.resnets.{diffusers_key}" + diffusers_key_hf = diffusers_key_hf.replace( + f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}" ) - converted_state_dict[diffusers_key] = state_dict.get(k) + converted_state_dict[diffusers_key_hf] = state_dict.get(k) else: for k in middle_blocks[key]: - diffusers_key = k.replace( - f"middle_block.{k}", f"mid_block.attentions.{diffusers_key}" + diffusers_key_hf = k.replace( + f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}" ) - converted_state_dict[diffusers_key] = state_dict.get(k) + converted_state_dict[diffusers_key_hf] = state_dict.get(k) # mid block converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight") @@ -383,6 +438,14 @@ def _convert_controlnet_to_diffusers(state_dict): f"input_hint_block.{cond_block_id}.bias" ) + for key in [key for key in state_dict if "input_hint_block.0" in key]: + diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in") + converted_state_dict[diffusers_key] = state_dict.get(key) + + for key in [key for key in state_dict if "input_hint_block.14" in key]: + diffusers_key = key.replace(f"input_hint_block.14", "controlnet_cond_embedding.conv_out") + 
converted_state_dict[diffusers_key] = state_dict.get(key) + return converted_state_dict state_dict = _convert_controlnet_to_diffusers(state_dict) From de6122638502eabb8a5af07912ba0d9c14b753f1 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 15 Feb 2025 13:38:02 +0000 Subject: [PATCH 06/19] 1 --- src/diffusers/loaders/controlnet.py | 341 ---------- .../loaders/lora_conversion_utils.py | 159 +---- src/diffusers/loaders/peft.py | 1 - src/diffusers/loaders/single_file_utils.py | 144 ----- .../models/controlnets/controlnet_lora.py | 593 ------------------ .../pipelines/control_lora/control_lora.py | 1 - 6 files changed, 6 insertions(+), 1233 deletions(-) delete mode 100644 src/diffusers/loaders/controlnet.py delete mode 100644 src/diffusers/models/controlnets/controlnet_lora.py diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py deleted file mode 100644 index 128ef6670796..000000000000 --- a/src/diffusers/loaders/controlnet.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from contextlib import nullcontext -from typing import Dict, List, Optional, Union - -import torch -from huggingface_hub.utils import validate_hf_hub_args - -from ..loaders.single_file_utils import DIFFUSERS_TO_LDM_MAPPING -from ..models.embeddings import ( - ImageProjection, - MultiIPAdapterImageProjection, -) -from ..models.modeling_utils import load_model_dict_into_meta -from ..utils import ( - USE_PEFT_BACKEND, - deprecate, - get_submodule_by_name, - is_accelerate_available, - is_peft_available, - is_peft_version, - is_torch_version, - is_transformers_available, - is_transformers_version, - logging, -) -from .lora_base import ( # noqa - LORA_WEIGHT_NAME, - LORA_WEIGHT_NAME_SAFE, - LoraBaseMixin, - _fetch_state_dict, - _load_lora_into_text_encoder, -) -from .lora_conversion_utils import ( - _convert_bfl_flux_control_lora_to_diffusers, - _convert_hunyuan_video_lora_to_diffusers, - _convert_kohya_flux_lora_to_diffusers, - _convert_non_diffusers_lora_to_diffusers, - _convert_stabilityai_control_lora_to_diffusers, - _convert_xlabs_flux_lora_to_diffusers, - _maybe_map_sgm_blocks_to_diffusers, -) - - -_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False -if is_torch_version(">=", "1.9.0"): - if ( - is_peft_available() - and is_peft_version(">=", "0.13.1") - and is_transformers_available() - and is_transformers_version(">", "4.45.2") - ): - _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True - -logger = logging.get_logger(__name__) - -CONTROLNET_NAME = "controlnet" - - -class ControlNetLoadersMixin: - """ - Load layers into a [`ControlNetModel`]. 
- """ - - _lora_loadable_modules = ["controlnet"] - controlnet_name = CONTROLNET_NAME - _control_lora_supported_norm_keys = ["norm1", "norm2", "norm3"] - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - return_alphas: bool = False, - **kwargs): - r""" - """ - - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_stabilityai = "lora_controlnet" in state_dict and "input_blocks.11.0.in_layers.0.weight" not in state_dict - if is_stabilityai: - state_dict = _convert_stabilityai_control_lora_to_diffusers(state_dict) - return (state_dict, None) if return_alphas else state_dict - - raise ValueError - - def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs - ) - - has_lora_keys = any("lora" in key for key in state_dict.keys()) - - # Control LoRAs also have norm keys - has_norm_keys = any( - norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys - ) - - if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") - - controlnet_lora_state_dict = { - k: state_dict.pop(k) for k in list(state_dict.keys()) if "lora" in k - } - controlnet_norm_state_dict = { - k: state_dict.pop(k) - for k in list(state_dict.keys()) - if any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) - } - controlnet_others_state_dict = { - k: state_dict.pop(k) for k in list(state_dict.keys()) - } - - controlnet = self - - if len(controlnet_lora_state_dict) > 0: - self.load_lora_into_controlnet( - controlnet_lora_state_dict, - network_alphas=network_alphas, - controlnet=controlnet, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - if len(controlnet_norm_state_dict) > 0: - self._load_norm_into_controlnet( - controlnet_norm_state_dict, - controlnet=controlnet, - discard_original_layers=False, - ) - - if len(controlnet_others_state_dict) > 0: - self._load_others_into_controlnet( - controlnet_others_state_dict, - controlnet=controlnet, - discard_original_layers=False, - ) - - @classmethod - def load_lora_into_controlnet( - cls, state_dict, network_alphas, controlnet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False - ): - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to controlnet. - logger.info(f"Loading {cls}.") - controlnet.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - def _load_norm_into_controlnet( - cls, - state_dict, - controlnet, - prefix=None, - discard_original_layers=False, - ) -> Dict[str, torch.Tensor]: - # Remove prefix if present - prefix = prefix or cls.controlnet_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) - - # Find invalid keys - controlnet_state_dict = controlnet.state_dict() - controlnet_keys = set(controlnet_state_dict.keys()) - state_dict_keys = set(state_dict.keys()) - extra_keys = list(state_dict_keys - controlnet_keys) - - if extra_keys: - logger.warning( - f"Unsupported keys found in state dict when trying to load normalization layers into the controlnet. The following keys will be ignored:\n{extra_keys}." - ) - - for key in extra_keys: - state_dict.pop(key) - - # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers_state_dict = {} - if not discard_original_layers: - for key in state_dict.keys(): - overwritten_layers_state_dict[key] = controlnet_state_dict[key].clone() - - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the controlnet " - 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. 
That is to say, the normalization layers will always be directly ' - "fused into the controlnet and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " - "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." - ) - - # We can't load with strict=True because the current state_dict does not contain all the controlnet keys - incompatible_keys = controlnet.load_state_dict(state_dict, strict=False) - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - - # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. - if unexpected_keys: - if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): - raise ValueError( - f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the controlnet." - ) - - return overwritten_layers_state_dict - - @classmethod - def _load_others_into_controlnet( - cls, - state_dict, - controlnet, - prefix=None, - discard_original_layers=False, - ) -> Dict[str, torch.Tensor]: - # Remove prefix if present - prefix = prefix or cls.controlnet_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) - - # Find invalid keys - controlnet_state_dict = controlnet.state_dict() - controlnet_keys = set(controlnet_state_dict.keys()) - state_dict_keys = set(state_dict.keys()) - extra_keys = list(state_dict_keys - controlnet_keys) - - if extra_keys: - logger.warning( - f"Unsupported keys found in state dict when trying to load normalization layers into the controlnet. The following keys will be ignored:\n{extra_keys}." - ) - - for key in extra_keys: - state_dict.pop(key) - - # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers_state_dict = {} - if not discard_original_layers: - for key in state_dict.keys(): - overwritten_layers_state_dict[key] = controlnet_state_dict[key].clone() - - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " - 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' - "fused into the controlnet and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " - "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." - ) - - # We can't load with strict=True because the current state_dict does not contain all the transformer keys - incompatible_keys = controlnet.load_state_dict(state_dict, strict=False) - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - - # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. - if unexpected_keys: - if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): - raise ValueError( - f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." 
- ) - - return overwritten_layers_state_dict - - def fuse_lora( - self, - components: List[str] = ["controlnet"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names - ) - - def unfuse_lora(self, components: List[str] = ["controlnet"], **kwargs): - super().unfuse_lora(components=components) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index e297d5063026..5116747a3119 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -17,12 +17,12 @@ import torch from ..utils import is_peft_version, logging -from .single_file_utils import ( - DIFFUSERS_TO_LDM_MAPPING, - LDM_CONTROLNET_KEY, - update_unet_resnet_ldm_to_diffusers, - update_unet_attention_ldm_to_diffusers -) +# from .single_file_utils import ( +# DIFFUSERS_TO_LDM_MAPPING, +# LDM_CONTROLNET_KEY, +# update_unet_resnet_ldm_to_diffusers, +# update_unet_attention_ldm_to_diffusers +# ) logger = logging.get_logger(__name__) @@ -1154,150 +1154,3 @@ def remap_single_transformer_blocks_(key, state_dict): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict - -def _convert_stabilityai_control_lora_to_diffusers(checkpoint): - # Return checkpoint if it's already been converted - if "time_embedding.linear_1.weight" in checkpoint: - return checkpoint - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - controlnet_state_dict = checkpoint - - else: - controlnet_state_dict = {} - keys = list(checkpoint.keys()) - controlnet_key = LDM_CONTROLNET_KEY - for key in keys: - if key.startswith(controlnet_key): - controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) - else: - controlnet_state_dict[key] = checkpoint.get(key) - - new_checkpoint = {} - ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] - for diffusers_key, ldm_key in ldm_controlnet_keys.items(): - if ldm_key not in controlnet_state_dict: - continue - new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] - for k, v in controlnet_state_dict.items(): - if "time_embed.0" in k: - new_checkpoint[k.replace("time_embed.0", "time_embedding.linear_1")] = v - elif "time_embed.2" in k: - new_checkpoint[k.replace("time_embed.2", "time_embedding.linear_2")] = v - elif "input_blocks.0.0" in k: - new_checkpoint[k.replace("input_blocks.0.0", "conv_in")] = v - elif "label_emb.0.0" in k: - new_checkpoint[k.replace("label_emb.0.0", "add_embedding.linear_1")] = v - elif "label_emb.0.2" in k: - new_checkpoint[k.replace("label_emb.0.2", "add_embedding.linear_2")] = v - elif "input_blocks.3.0.op" in k: - new_checkpoint[k.replace("input_blocks.3.0.op", "down_blocks.0.downsamplers.0.conv")] = v - elif "input_blocks.6.0.op" in k: - new_checkpoint[k.replace("input_blocks.6.0.op", "down_blocks.1.downsamplers.0.conv")] = v - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} - ) - input_blocks = { - layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Down blocks - for i in 
range(1, num_input_blocks): - block_id = (i - 1) // (2 + 1) - layer_in_block_id = (i - 1) % (2 + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - update_unet_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, - ) - - if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( - f"input_blocks.{i}.0.op.bias" - ) - - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if attentions: - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, - ) - - # controlnet down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} - ) - middle_blocks = { - layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Mid blocks - for key in middle_blocks.keys(): - diffusers_key = max(key - 1, 0) - if key % 2 == 0: - update_unet_resnet_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, - ) - else: - update_unet_attention_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, - ) - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") - - # controlnet cond embedding blocks - cond_embedding_blocks = { - ".".join(layer.split(".")[:2]) - for layer in controlnet_state_dict - if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) - } - num_cond_embedding_blocks = len(cond_embedding_blocks) - - for idx in range(1, num_cond_embedding_blocks + 1): - diffusers_idx = idx - 1 - cond_block_id = 2 * idx - - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.bias" - ) - - new_new_checkpoint = {} - for k, v in controlnet_state_dict.items(): - if ".down" in k: - new_new_checkpoint[k.replace(".down", "lora_A.weight")] = v - elif ".up" in k: - new_new_checkpoint[k.replace(".up", "lora_B.weight")] = v - new_new_checkpoint[k] = v - - return new_new_checkpoint diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a8ee406ae42d..4a1f7de575a2 100644 --- 
a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -21,7 +21,6 @@ import safetensors import torch -from ..loaders.lora_conversion_utils import _convert_stabilityai_control_lora_to_diffusers from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ba00911b40ef..731b7b87f625 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1311,150 +1311,6 @@ def convert_controlnet_checkpoint( return new_checkpoint -def convert_control_lora_checkpoint( - checkpoint, - config, - **kwargs, -): - # Return checkpoint if it's already been converted - if "time_embedding.linear_1.weight" in checkpoint: - return checkpoint - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - controlnet_state_dict = checkpoint - - else: - controlnet_state_dict = {} - keys = list(checkpoint.keys()) - controlnet_key = LDM_CONTROLNET_KEY - for key in keys: - if key.startswith(controlnet_key): - controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) - else: - controlnet_state_dict[key] = checkpoint.get(key) - - new_checkpoint = {} - ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] - for diffusers_key, ldm_key in ldm_controlnet_keys.items(): - if ldm_key not in controlnet_state_dict: - continue - new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] - for k, v in controlnet_state_dict.items(): - if "time_embed.0" in k: - new_checkpoint[k.replace("time_embed.0", "time_embedding.linear_1")] = v - elif "time_embed.2" in k: - new_checkpoint[k.replace("time_embed.2", "time_embedding.linear_2")] = v - elif "input_blocks.0.0" in k: - new_checkpoint[k.replace("input_blocks.0.0", "conv_in")] = v - elif "label_emb.0.0" in k: - new_checkpoint[k.replace("label_emb.0.0", "add_embedding.linear_1")] = v - elif "label_emb.0.2" in k: - new_checkpoint[k.replace("label_emb.0.2", "add_embedding.linear_2")] = v - elif "input_blocks.3.0.op" in k: - new_checkpoint[k.replace("input_blocks.3.0.op", "down_blocks.0.downsamplers.0.conv")] = v - elif "input_blocks.6.0.op" in k: - new_checkpoint[k.replace("input_blocks.6.0.op", "down_blocks.1.downsamplers.0.conv")] = v - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} - ) - input_blocks = { - layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Down blocks - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - update_unet_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, - ) - - if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( - 
f"input_blocks.{i}.0.op.bias" - ) - - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if attentions: - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, - ) - - # controlnet down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} - ) - middle_blocks = { - layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Mid blocks - for key in middle_blocks.keys(): - diffusers_key = max(key - 1, 0) - if key % 2 == 0: - update_unet_resnet_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, - ) - else: - update_unet_attention_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, - ) - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") - - # controlnet cond embedding blocks - cond_embedding_blocks = { - ".".join(layer.split(".")[:2]) - for layer in controlnet_state_dict - if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) - } - num_cond_embedding_blocks = len(cond_embedding_blocks) - - for idx in range(1, num_cond_embedding_blocks + 1): - diffusers_idx = idx - 1 - cond_block_id = 2 * idx - - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.bias" - ) - - return new_checkpoint - - def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys diff --git a/src/diffusers/models/controlnets/controlnet_lora.py b/src/diffusers/models/controlnets/controlnet_lora.py deleted file mode 100644 index 65f94cf65a56..000000000000 --- a/src/diffusers/models/controlnets/controlnet_lora.py +++ /dev/null @@ -1,593 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import math -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn -from torch.nn import functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import BaseOutput, logging -from ..attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) -from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps -from ..modeling_utils import ModelMixin -from ..unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, - UNetMidBlock2D, - UNetMidBlock2DCrossAttn, - get_down_block, -) -from ..unets.unet_2d_condition import UNet2DConditionModel -from .controlnet import ControlNetConditioningEmbedding, ControlNetModel, zero_module - -from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer -from ...loaders.single_file_utils import load_single_file_checkpoint, convert_controlnet_checkpoint, convert_control_lora_checkpoint - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetLoRAOutput(BaseOutput): - """ - The output of [`ControlNetLoRAModel`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the middle block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class ControlNetLoRAModel(ControlNetModel): - @register_to_config - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 3, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - projection_class_embeddings_input_dim: Optional[int] = None, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - addition_embed_type_num_heads: int = 64, - ): - super().__init__() - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
- ) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - ) - - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - # control net conditioning embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[i], - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channel = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=mid_block_channel, - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2D": - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - num_layers=0, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False, - ) - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - for name, module in self.named_modules(): - *parent_module_path, attr_name = name.split(".") - parent_module = self - for path_part in parent_module_path: - parent_module = getattr(parent_module, path_part) - - if isinstance(module, nn.Linear): - module = LinearWithLoRA( - in_features=module.in_features, - out_features=module.out_features, - bias=False if module.bias is None else True, - ) - setattr(parent_module, attr_name, module) - elif isinstance(module, nn.Conv2d): - module = Conv2dWithLoRA( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - bias=False if module.bias is None else True - ) - setattr(parent_module, attr_name, module) - - @classmethod - def from_unet_and_single_file( - cls, - unet: UNet2DConditionModel, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - load_weights_from_unet: bool = True, - conditioning_channels: int = 3, - pretrained_model_link_or_path_or_dict: Optional[str] = None, - ): - r""" - Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied - where applicable. 
- """ - transformer_layers_per_block = ( - unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 - ) - encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None - encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None - addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None - addition_time_embed_dim = ( - unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None - ) - - controllora = cls( - encoder_hid_dim=encoder_hid_dim, - encoder_hid_dim_type=encoder_hid_dim_type, - addition_embed_type=addition_embed_type, - addition_time_embed_dim=addition_time_embed_dim, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=unet.config.in_channels, - flip_sin_to_cos=unet.config.flip_sin_to_cos, - freq_shift=unet.config.freq_shift, - down_block_types=unet.config.down_block_types, - only_cross_attention=unet.config.only_cross_attention, - block_out_channels=unet.config.block_out_channels, - layers_per_block=unet.config.layers_per_block, - downsample_padding=unet.config.downsample_padding, - mid_block_scale_factor=unet.config.mid_block_scale_factor, - act_fn=unet.config.act_fn, - norm_num_groups=unet.config.norm_num_groups, - norm_eps=unet.config.norm_eps, - cross_attention_dim=unet.config.cross_attention_dim, - attention_head_dim=unet.config.attention_head_dim, - num_attention_heads=unet.config.num_attention_heads, - use_linear_projection=unet.config.use_linear_projection, - class_embed_type=unet.config.class_embed_type, - num_class_embeds=unet.config.num_class_embeds, - upcast_attention=unet.config.upcast_attention, - resnet_time_scale_shift=unet.config.resnet_time_scale_shift, - projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, - mid_block_type=unet.config.mid_block_type, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - if load_weights_from_unet: - controllora.conv_in.load_state_dict(unet.conv_in.state_dict()) - controllora.time_proj.load_state_dict(unet.time_proj.state_dict()) - controllora.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - - if controllora.class_embedding: - controllora.class_embedding.load_state_dict(unet.class_embedding.state_dict()) - - if hasattr(controllora, "add_embedding"): - controllora.add_embedding.load_state_dict(unet.add_embedding.state_dict()) - - controllora.down_blocks.load_state_dict(unet.down_blocks.state_dict()) - controllora.mid_block.load_state_dict(unet.mid_block.state_dict()) - - if isinstance(pretrained_model_link_or_path_or_dict, dict): - checkpoint = pretrained_model_link_or_path_or_dict - elif isinstance(pretrained_model_link_or_path_or_dict, str): - checkpoint = load_single_file_checkpoint( - pretrained_model_link_or_path_or_dict, - ) - else: - raise ValueError - - config = ControlNetModel.load_config("xinsir/controlnet-canny-sdxl-1.0") - checkpoint = convert_control_lora_checkpoint(checkpoint, config) - - for name, param in checkpoint.items(): - *parent_module_path, attr_name = name.split(".") - parent_module = controllora - for path_part in parent_module_path: - parent_module = getattr(parent_module, path_part) - - if getattr(parent_module, attr_name, None) is None: - setattr(parent_module, attr_name, 
param.to(controllora.device)) - missing, unexpected = controllora.load_state_dict(checkpoint, strict=False) - # print("missing: ", missing) - # print("unexpected: ", unexpected) - - return controllora - - -class LinearWithLoRA(nn.Module): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) - if bias: - self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - self.up = None - self.down = None - - def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.up is not None and self.down is not None: - weight = self.weight + torch.mm(self.up.to(self.weight.device), self.down.to(self.weight.device)) - return F.linear(input, weight, self.bias) - else: - return F.linear(input, self.weight, self.bias) - - def extra_repr(self) -> str: - return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, up={self.up is not None}, down={self.down is not None}" - - -class Conv2dWithLoRA(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[str, Union[int, Tuple[int, int]]] = 0, - dilation: Union[int, Tuple[int, int]] = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", # TODO: refine this type - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.padding_mode = padding_mode - - self.weight = nn.Parameter( - torch.empty( - (out_channels, in_channels // groups, *kernel_size), - **factory_kwargs, - ) - ) - if bias: - self.bias = nn.Parameter(torch.empty(out_channels, **factory_kwargs)) - else: - self.register_parameter("bias", None) - - self.reset_parameters() - - self.up = None - self.down = None - - def reset_parameters(self) -> None: - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) - # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.up is not None and self.down is not None: - weight = self.weight + 
torch.mm(self.up.flatten(1).to(self.weight.device), self.down.flatten(1).to(self.weight.device)).reshape(self.weight.shape) - return F.conv2d( - input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups - ) - else: - return F.conv2d( - input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups - ) - - def extra_repr(self): - s = ( - "{in_channels}, {out_channels}, kernel_size={kernel_size}" - ", stride={stride}" - ) - if self.padding != (0,) * len(self.padding): - s += ", padding={padding}" - if self.bias is None: - s += ", bias=False" - if self.up is not None: - s += ", up=True" - if self.down is not None: - s += ", down=True" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" - - -if __name__ == "__main__": - pass diff --git a/src/diffusers/pipelines/control_lora/control_lora.py b/src/diffusers/pipelines/control_lora/control_lora.py index c1a7512e62d0..8ffc735d8d16 100644 --- a/src/diffusers/pipelines/control_lora/control_lora.py +++ b/src/diffusers/pipelines/control_lora/control_lora.py @@ -9,7 +9,6 @@ import torch pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" - controlnet_id = "xinsir/controlnet-canny-sdxl-1.0" lora_id = "stabilityai/control-lora" lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" From 10daac7e198601cc170dbf7dc122a8e1cb03d984 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 15 Feb 2025 13:41:30 +0000 Subject: [PATCH 07/19] 1 --- src/diffusers/loaders/lora_conversion_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 5116747a3119..e064aeba43b6 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -17,12 +17,6 @@ import torch from ..utils import is_peft_version, logging -# from .single_file_utils import ( -# DIFFUSERS_TO_LDM_MAPPING, -# LDM_CONTROLNET_KEY, -# update_unet_resnet_ldm_to_diffusers, -# update_unet_attention_ldm_to_diffusers -# ) logger = logging.get_logger(__name__) From 523967f396d84fe2333db21efc4055fd9bac9e63 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 15 Feb 2025 13:45:51 +0000 Subject: [PATCH 08/19] 1 --- .../pipelines/control_lora/__init__.py | 66 - .../pipelines/control_lora/control_lora.py | 57 - .../pipeline_control_lora_sd_xl.py | 1272 ----------------- 3 files changed, 1395 deletions(-) delete mode 100644 src/diffusers/pipelines/control_lora/__init__.py delete mode 100644 src/diffusers/pipelines/control_lora/control_lora.py delete mode 100644 src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py diff --git a/src/diffusers/pipelines/control_lora/__init__.py b/src/diffusers/pipelines/control_lora/__init__.py deleted file mode 100644 index caeac8ac61ec..000000000000 --- a/src/diffusers/pipelines/control_lora/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - 
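Editorial note on the `LinearWithLoRA` / `Conv2dWithLoRA` wrappers defined in `controlnet_lora.py` above: at forward time they merge a low-rank `up @ down` product into the frozen base weight, and `Conv2dWithLoRA` does the same after flattening both factors to 2-D and reshaping the product back to the conv weight's shape. A minimal, self-contained sketch of that identity; all shapes and values below are made up purely for illustration:

import torch
import torch.nn.functional as F

out_features, in_features, rank = 64, 64, 8        # hypothetical sizes
weight = torch.randn(out_features, in_features)    # frozen base weight
up = 0.01 * torch.randn(out_features, rank)        # LoRA "up" factor
down = 0.01 * torch.randn(rank, in_features)       # LoRA "down" factor

x = torch.randn(2, in_features)
# What LinearWithLoRA.forward computes: fold the low-rank product into the
# weight, then run an ordinary linear layer.
merged = F.linear(x, weight + up @ down)
# The common two-branch LoRA formulation gives the same result.
two_branch = F.linear(x, weight) + F.linear(F.linear(x, down), up)
assert torch.allclose(merged, two_branch, atol=1e-4)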
- _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["pipeline_control_lora_sd_xl"] = ["StableDiffusionXLControlLoRAPipeline"] -try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] - - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .pipeline_control_lora_sd_xl import StableDiffusionXLControlNetPipeline - - try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline - - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/control_lora/control_lora.py b/src/diffusers/pipelines/control_lora/control_lora.py deleted file mode 100644 index 8ffc735d8d16..000000000000 --- a/src/diffusers/pipelines/control_lora/control_lora.py +++ /dev/null @@ -1,57 +0,0 @@ - - -if __name__ == "__main__": - from diffusers import ( - StableDiffusionXLControlNetPipeline, - ControlNetModel, - UNet2DConditionModel, - ) - import torch - - pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" - lora_id = "stabilityai/control-lora" - lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" - - - unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.float16).to("cuda") - controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.float16) - controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, controlnet_config=controlnet.config) - - from diffusers import AutoencoderKL - from diffusers.utils import load_image, make_image_grid - from PIL import Image - import numpy as np - import cv2 - - prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" - negative_prompt = "low quality, bad quality, sketches" - - image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") - - controlnet_conditioning_scale = 1.0 # recommended for good generalization - - vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - pipe = StableDiffusionXLControlNetPipeline.from_pretrained( - pipe_id, - unet=unet, - controlnet=controlnet, - vae=vae, - torch_dtype=torch.float16, - safety_checker=None, - ).to("cuda") - - image = np.array(image) - image = cv2.Canny(image, 100, 200) - image = image[:, :, None] - image = np.concatenate([image, image, image], axis=2) - image = Image.fromarray(image) - - images = pipe( - prompt, negative_prompt=negative_prompt, image=image, - 
controlnet_conditioning_scale=controlnet_conditioning_scale, - num_images_per_prompt=4 - ).images - - final_image = [image] + images - grid = make_image_grid(final_image, 1, 5) - grid.save(f"hf-logo1.png") diff --git a/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py b/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py deleted file mode 100644 index 033933029907..000000000000 --- a/src/diffusers/pipelines/control_lora/pipeline_control_lora_sd_xl.py +++ /dev/null @@ -1,1272 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from diffusers.utils.import_utils import is_invisible_watermark_available - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - StableDiffusionXLLoraLoaderMixin, - TextualInversionLoaderMixin, -) -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import ( - AttnProcessor2_0, - XFormersAttnProcessor, -) -from ...models.lora import adjust_lora_scale_text_encoder -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput - - -if is_invisible_watermark_available(): - from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker - - -from ...utils import is_torch_xla_available - -from ...models.controlnets.controlnet_lora import ControlNetLoRAModel - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" - >>> negative_prompt = "low quality, bad quality, sketches" - - >>> # download an image - >>> image = load_image( - ... 
"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" - ... ) - - >>> # initialize the models and pipeline - >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization - >>> controlnet = ControlNetModel.from_pretrained( - ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 - ... ) - >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 - ... ) - >>> pipe.enable_model_cpu_offload() - - >>> # get canny image - >>> image = np.array(image) - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # generate image - >>> image = pipe( - ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image - ... ).images[0] - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. 
Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class StableDiffusionXLControlLoRAPipeline( - DiffusionPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - FromSingleFileMixin, -): - r""" - """ - - # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" - _optional_components = [ - "tokenizer", - "tokenizer_2", - "text_encoder", - "text_encoder_2", - "feature_extractor", - ] - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - "add_text_embeds", - "add_time_ids", - "negative_pooled_prompt_embeds", - "negative_add_time_ids", - ] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - text_encoder_2: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetLoRAModel, - scheduler: KarrasDiffusionSchedulers, - force_zeros_for_empty_prompt: bool = True, - add_watermarker: Optional[bool] = None, - feature_extractor: CLIPImageProcessor = None, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() - - if add_watermarker: - self.watermark = StableDiffusionXLWatermarker() - else: - self.watermark = None - - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - - # We are only ALWAYS interested in the pooled output of the final text encoder - if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, 
num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - prompt_2, - image, - callback_steps, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - callback_on_step_end_tensor_inputs=None, - ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - # Check `image` and `controlnet_conditioning_scale` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.unet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetLoRAModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetLoRAModel) - ): - self.check_image(image, prompt, prompt_embeds) - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - else: - assert False - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
- ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def clip_skip(self): - return self._clip_skip - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs - - @property - def denoising_end(self): - return self._denoising_end - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - # @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - timesteps: List[int] = None, - sigmas: List[float] = None, - denoising_end: Optional[float] = None, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, - guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, - ): - r""" - """ - - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not 
isinstance(control_guidance_end, list): - control_guidance_start, control_guidance_end = ( - [control_guidance_start], - [control_guidance_end], - ) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - image, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - self._denoising_end = denoising_end - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - global_pool_conditions = controlnet.config.global_pool_conditions - guess_mode = guess_mode or global_pool_conditions - - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - - # 4. Prepare image - if isinstance(controlnet, ControlNetLoRAModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = image.shape[-2:] - else: - assert False - - # 5. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self._num_timesteps = len(timesteps) - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6.5 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetLoRAModel) else keeps) - - # 7.2 Prepare added time ids & embeddings - if isinstance(image, list): - original_size = original_size or image[0].shape[-2:] - else: - original_size = original_size or image.shape[-2:] - target_size = target_size or (height, width) - - add_text_embeds = pooled_prompt_embeds - if self.text_encoder_2 is None: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - else: - text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - - # 8.1 Apply denoising_end - if ( - self.denoising_end is not None - and isinstance(self.denoising_end, float) - and self.denoising_end > 0 - and self.denoising_end < 1 - ): - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - # control-lora inference - if guess_mode and self.do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. 
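                    # Editorial note (not part of the original diff): the classifier-free-guidance
                    # batch is ordered [negative, positive] (see the torch.cat calls above), so the
                    # `chunk(2)[1]` calls below hand the ControlNet only the conditional half;
                    # zeros are concatenated back in for the unconditional half after the
                    # ControlNet call.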
- control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - controlnet_added_cond_kwargs = { - "text_embeds": add_text_embeds.chunk(2)[1], - "time_ids": add_time_ids.chunk(2)[1], - } - else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = added_cond_kwargs - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, - return_dict=False, - ) - - if guess_mode and self.do_classifier_free_guidance: - # Inferred ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=self.cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) - image = callback_outputs.pop("image", image) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if XLA_AVAILABLE: - xm.mark_step() - - if not output_type == "latent": - # make sure the 
VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor - - image = self.vae.decode(latents, return_dict=False)[0] - - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) - - -if __name__ == "__main__": - from diffusers import ( - StableDiffusionXLControlNetPipeline, - ControlNetModel, - UNet2DConditionModel, - ) - import torch - - pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" - controlnet_id = "xinsir/controlnet-canny-sdxl-1.0" - lora_id = "stabilityai/control-lora" - lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" - - - unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.float16).to("cuda") - controlnet = ControlNetLoRAModel.from_unet_and_single_file(unet, pretrained_model_link_or_path_or_dict="https://huggingface.co/stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors").to("cuda").to(torch.float16) - - from diffusers import AutoencoderKL - from diffusers.utils import load_image, make_image_grid - from PIL import Image - import numpy as np - import cv2 - - prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" - negative_prompt = "low quality, bad quality, sketches" - - image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") - - controlnet_conditioning_scale = 1.0 # recommended for good generalization - - vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - pipe = StableDiffusionXLControlLoRAPipeline.from_pretrained( - pipe_id, - unet=unet, - controlnet=controlnet, - vae=vae, - torch_dtype=torch.float16, - safety_checker=None, - ).to("cuda") - - image = np.array(image) - image = cv2.Canny(image, 100, 200) - image = image[:, :, None] - image = np.concatenate([image, image, image], axis=2) - image = Image.fromarray(image) - - images = pipe( - prompt, negative_prompt=negative_prompt, image=image, - controlnet_conditioning_scale=controlnet_conditioning_scale, - num_images_per_prompt=4 - ).images - - final_image = [image] + images - grid = make_image_grid(final_image, 1, 5) - 
grid.save(f"hf-logo1.png") From 7c25a065913731e3068440195f27f2bf9aa2bf22 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sun, 23 Mar 2025 09:36:00 +0000 Subject: [PATCH 09/19] fix PeftAdapterMixin --- src/diffusers/loaders/__init__.py | 5 +- src/diffusers/loaders/peft.py | 215 +++--------------- .../models/controlnets/controlnet.py | 4 +- 3 files changed, 38 insertions(+), 186 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b0cd85dff916..3ba1bfacf3dd 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -84,7 +84,7 @@ def text_encoder_attn_modules(text_encoder): "SD3IPAdapterMixin", ] -_import_structure["peft"] = ["PeftAdapterMixin", "ControlLoRAMixin"] +_import_structure["peft"] = ["PeftAdapterMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -94,7 +94,6 @@ def text_encoder_attn_modules(text_encoder): from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers - from .peft import ControlLoRAMixin if is_transformers_available(): from .ip_adapter import ( @@ -121,7 +120,7 @@ def text_encoder_attn_modules(text_encoder): from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin - from .peft import PeftAdapterMixin, ControlLoRAMixin + from .peft import PeftAdapterMixin else: import sys diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 91e931f44e1b..af705575849f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -114,6 +114,31 @@ def _maybe_adjust_config(config): return config +def _maybe_adjust_config_for_control_lora(config): + """ + """ + + target_modules_before = config["target_modules"] + target_modules = [] + modules_to_save = [] + + for module in target_modules_before: + if "base_layer" in module: + continue + elif "modules_to_save" in module: + base_name = module.split(".modules_to_save.", 1)[0] + modules_to_save.append(base_name) + else: + base_name = ".".join(module.split(".")[:-1]) + if base_name and base_name not in modules_to_save: + target_modules.append(module) + + config["target_modules"] = target_modules + config["modules_to_save"] = modules_to_save + + return config + + class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For @@ -245,6 +270,13 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora/ + is_control_lora = "lora_controlnet" in state_dict + if is_control_lora: + del state_dict["lora_controlnet"] + state_dict = convert_control_lora_state_dict_to_peft(state_dict) + # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: @@ -265,6 +297,8 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. 
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) + if is_control_lora: + lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -767,184 +801,3 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): # Pop also the corresponding adapter from the config if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) - - -class ControlLoRAMixin(PeftAdapterMixin): - TARGET_MODULES = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", "proj_in", "proj_out", - "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "linear_1", "linear_2", "time_emb_proj"] - SAVE_MODULES = ["controlnet_cond_embedding.conv_in", "controlnet_cond_embedding.blocks.0", - "controlnet_cond_embedding.blocks.1", "controlnet_cond_embedding.blocks.2", - "controlnet_cond_embedding.blocks.3", "controlnet_cond_embedding.blocks.4", - "controlnet_cond_embedding.blocks.5", "controlnet_cond_embedding.conv_out", - "controlnet_down_blocks.0", "controlnet_down_blocks.1", "controlnet_down_blocks.2", - "controlnet_down_blocks.3", "controlnet_down_blocks.4", "controlnet_down_blocks.5", - "controlnet_down_blocks.6", "controlnet_down_blocks.7", "controlnet_down_blocks.8", - "controlnet_mid_block", "norm", "norm1", "norm2", "norm3"] - - def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - from peft.tuners.tuners_utils import BaseTunerLayer - - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - adapter_name = kwargs.pop("adapter_name", None) - network_alphas = kwargs.pop("network_alphas", None) - _pipeline = kwargs.pop("_pipeline", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) - allow_pickle = False - - if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - if network_alphas is not None and prefix is None: - raise ValueError("`network_alphas` cannot be None when `prefix` is None.") - - if prefix is not None: - keys = list(state_dict.keys()) - model_keys = [k for k in keys if k.startswith(f"{prefix}.")] - if len(model_keys) > 0: - state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} - - if len(state_dict) > 0: - if adapter_name in getattr(self, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." 
- ) - - # check with first key if is not in peft format - if "lora_controlnet" in state_dict: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) - - rank = {} - for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. - # Bias layers in LoRA only have a single dimension - if "lora_B" in key and val.ndim > 1: - rank[key] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config_kwargs["bias"] = "all" - lora_config_kwargs["target_modules"] = self.TARGET_MODULES - lora_config_kwargs["modules_to_save"] = self.SAVE_MODULES - lora_config = LoraConfig(**lora_config_kwargs) - # adapter_name - if adapter_name is None: - adapter_name = "default" - - # =", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, - # we should also delete the `peft_config` associated to the `adapter_name`. - try: - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) - except Exception as e: - # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. - if hasattr(self, "peft_config"): - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - active_adapters = module.active_adapters - for active_adapter in active_adapters: - if adapter_name in active_adapter: - module.delete_adapter(adapter_name) - - self.peft_config.pop(adapter_name) - logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") - raise - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. 
- missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index c1404c48c3cb..e49556c035d6 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, ControlLoRAMixin +from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention_processor import ( @@ -107,7 +107,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, ControlLoRAMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): """ A ControlNet model. From 0719c20f5e05367c47f190d7b77d15754556bb2c Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sun, 23 Mar 2025 10:27:40 +0000 Subject: [PATCH 10/19] fix module_to_save bug --- src/diffusers/loaders/peft.py | 23 ++++++++++++++--------- src/diffusers/utils/state_dict_utils.py | 21 --------------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index af705575849f..35b9c98a5ba3 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -123,18 +123,20 @@ def _maybe_adjust_config_for_control_lora(config): modules_to_save = [] for module in target_modules_before: - if "base_layer" in module: - continue - elif "modules_to_save" in module: - base_name = module.split(".modules_to_save.", 1)[0] + if module.endswith("weight"): + base_name = ".".join(module.split(".")[:-1]) modules_to_save.append(base_name) - else: + elif module.endswith("bias"): base_name = ".".join(module.split(".")[:-1]) - if base_name and base_name not in modules_to_save: - target_modules.append(module) + if ".".join([base_name, "weight"]) in target_modules_before: + modules_to_save.append(base_name) + else: + target_modules.append(base_name) + else: + target_modules.append(module) - config["target_modules"] = target_modules - config["modules_to_save"] = modules_to_save + config["target_modules"] = list(set(target_modules)) + config["modules_to_save"] = list(set(modules_to_save)) return config @@ -299,6 +301,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) if is_control_lora: lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) + import json + with open("lora_config_kwargs.json", "w") as f: + json.dump(lora_config_kwargs, f, indent=2) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 322b118a6517..da7e10c7e9c1 100644 --- 
a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -51,55 +51,34 @@ class StateDictType(enum.Enum): } CONTROL_LORA_TO_DIFFUSERS = { - ".to_q.bias": ".to_q.base_layer.bias", ".to_q.down": ".to_q.lora_A.weight", ".to_q.up": ".to_q.lora_B.weight", - ".to_k.bias": ".to_k.base_layer.bias", ".to_k.down": ".to_k.lora_A.weight", ".to_k.up": ".to_k.lora_B.weight", - ".to_v.bias": ".to_v.base_layer.bias", ".to_v.down": ".to_v.lora_A.weight", ".to_v.up": ".to_v.lora_B.weight", - ".to_out.0.bias": ".to_out.0.base_layer.bias", ".to_out.0.down": ".to_out.0.lora_A.weight", ".to_out.0.up": ".to_out.0.lora_B.weight", - ".ff.net.0.proj.bias": ".ff.net.0.proj.base_layer.bias", ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight", ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight", - ".ff.net.2.bias": ".ff.net.2.base_layer.bias", ".ff.net.2.down": ".ff.net.2.lora_A.weight", ".ff.net.2.up": ".ff.net.2.lora_B.weight", - ".proj_in.bias": ".proj_in.base_layer.bias", ".proj_in.down": ".proj_in.lora_A.weight", ".proj_in.up": ".proj_in.lora_B.weight", - ".proj_out.bias": ".proj_out.base_layer.bias", ".proj_out.down": ".proj_out.lora_A.weight", ".proj_out.up": ".proj_out.lora_B.weight", - ".conv.bias": ".conv.base_layer.bias", ".conv.down": ".conv.lora_A.weight", ".conv.up": ".conv.lora_B.weight", - **{f".conv{i}.bias": f".conv{i}.base_layer.bias" for i in range(1, 3)}, **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)}, **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)}, - "conv_in.bias": "conv_in.base_layer.bias", "conv_in.down": "conv_in.lora_A.weight", "conv_in.up": "conv_in.lora_B.weight", - ".conv_shortcut.bias": ".conv_shortcut.base_layer.bias", ".conv_shortcut.down": ".conv_shortcut.lora_A.weight", ".conv_shortcut.up": ".conv_shortcut.lora_B.weight", - **{f".linear_{i}.bias": f".linear_{i}.base_layer.bias" for i in range(1, 3)}, **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)}, **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)}, - "time_emb_proj.bias": "time_emb_proj.base_layer.bias", "time_emb_proj.down": "time_emb_proj.lora_A.weight", "time_emb_proj.up": "time_emb_proj.lora_B.weight", - "controlnet_cond_embedding.conv_in.bias": "controlnet_cond_embedding.conv_in.modules_to_save.bias", - "controlnet_cond_embedding.conv_out.bias": "controlnet_cond_embedding.conv_out.modules_to_save.bias", - **{f"controlnet_cond_embedding.blocks.{i}.bias": f"controlnet_cond_embedding.blocks.{i}.modules_to_save.bias" for i in range(6)}, - **{f"controlnet_down_blocks.{i}.bias": f"controlnet_down_blocks.{i}.modules_to_save.bias" for i in range(9)}, - "controlnet_mid_block.bias": "controlnet_mid_block.modules_to_save.bias", - ".norm.bias": ".norm.modules_to_save.bias", - **{f".norm{i}.bias": f".norm{i}.modules_to_save.bias" for i in range(1, 4)}, } DIFFUSERS_TO_PEFT = { From 81eed41b748e0fedc6448d6022e867ccd492deb6 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sun, 23 Mar 2025 10:29:08 +0000 Subject: [PATCH 11/19] delete json print --- src/diffusers/loaders/peft.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 35b9c98a5ba3..3aa06ffb8ef0 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -301,9 +301,6 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) if is_control_lora: lora_config_kwargs = 
_maybe_adjust_config_for_control_lora(lora_config_kwargs) - import json - with open("lora_config_kwargs.json", "w") as f: - json.dump(lora_config_kwargs, f, indent=2) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: From 6a1ff82d0830484202cf3166229add29e4626609 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 9 Apr 2025 07:41:14 +0000 Subject: [PATCH 12/19] resolve conflits --- src/diffusers/loaders/peft.py | 227 ++++++++++++++++++++------------ src/diffusers/utils/__init__.py | 2 +- 2 files changed, 142 insertions(+), 87 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3aa06ffb8ef0..9165c46f3c78 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -16,7 +16,7 @@ import os from functools import partial from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Optional, Union import safetensors import torch @@ -25,7 +25,6 @@ MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, - convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -59,23 +58,11 @@ } -def _maybe_adjust_config(config): - """ - We may run into some ambiguous configuration values when a model has module names, sharing a common prefix - (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This - method removes the ambiguity by following what is described here: - https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. - """ - # Track keys that have been explicitly removed to prevent re-adding them. - deleted_keys = set() - +def _maybe_raise_error_for_ambiguity(config): rank_pattern = config["rank_pattern"].copy() target_modules = config["target_modules"] - original_r = config["r"] for key in list(rank_pattern.keys()): - key_rank = rank_pattern[key] - # try to detect ambiguity # `target_modules` can also be a str, in which case this loop would loop # over the chars of the str. The technically correct way to match LoRA keys @@ -83,62 +70,12 @@ def _maybe_adjust_config(config): # But this cuts it for now. 
exact_matches = [mod for mod in target_modules if mod == key] substring_matches = [mod for mod in target_modules if key in mod and mod != key] - ambiguous_key = key if exact_matches and substring_matches: - # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example) - config["r"] = key_rank - # remove the ambiguous key from `rank_pattern` and record it as deleted - del config["rank_pattern"][key] - deleted_keys.add(key) - # For substring matches, add them with the original rank only if they haven't been assigned already - for mod in substring_matches: - if mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Update the rest of the target modules with the original rank if not already set and not deleted - for mod in target_modules: - if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Handle alphas to deal with cases like: - # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 - has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] - if has_different_ranks: - config["lora_alpha"] = config["r"] - alpha_pattern = {} - for module_name, rank in config["rank_pattern"].items(): - alpha_pattern[module_name] = rank - config["alpha_pattern"] = alpha_pattern - - return config - - -def _maybe_adjust_config_for_control_lora(config): - """ - """ - - target_modules_before = config["target_modules"] - target_modules = [] - modules_to_save = [] - - for module in target_modules_before: - if module.endswith("weight"): - base_name = ".".join(module.split(".")[:-1]) - modules_to_save.append(base_name) - elif module.endswith("bias"): - base_name = ".".join(module.split(".")[:-1]) - if ".".join([base_name, "weight"]) in target_modules_before: - modules_to_save.append(base_name) - else: - target_modules.append(base_name) - else: - target_modules.append(module) - - config["target_modules"] = list(set(target_modules)) - config["modules_to_save"] = list(set(modules_to_save)) - - return config + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." + ) class PeftAdapterMixin: @@ -156,6 +93,8 @@ class PeftAdapterMixin: """ _hf_peft_config_loaded = False + # kwargs for prepare_model_for_compiled_hotswap, if required + _prepare_lora_hotswap_kwargs: Optional[dict] = None @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading @@ -173,7 +112,9 @@ def _optionally_disable_offloading(cls, _pipeline): """ return _func_optionally_disable_offloading(_pipeline=_pipeline) - def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + def load_lora_adapter( + self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs + ): r""" Loads a LoRA adapter into the underlying model. @@ -217,6 +158,29 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. 
This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -267,17 +231,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: - if adapter_name in getattr(self, "peft_config", {}): + if adapter_name in getattr(self, "peft_config", {}) and not hotswap: raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) - - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora/ - is_control_lora = "lora_controlnet" in state_dict - if is_control_lora: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) + elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: + raise ValueError( + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " + "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." + ) # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) @@ -289,18 +251,18 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. # Bias layers in LoRA only have a single dimension if "lora_B" in key and val.ndim > 1: - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. - rank[key] = val.shape[1] + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. 
+ rank[f"^{key}"] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. - lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) - if is_control_lora: - lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) + _maybe_raise_error_for_ambiguity(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -339,11 +301,71 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + if hotswap or (self._prepare_lora_hotswap_kwargs is not None): + if is_peft_version(">", "0.14.0"): + from peft.utils.hotswap import ( + check_hotswap_configs_compatible, + hotswap_adapter_from_state_dict, + prepare_model_for_compiled_hotswap, + ) + else: + msg = ( + "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " + "from source." + ) + raise ImportError(msg) + + if hotswap: + + def map_state_dict_for_hotswap(sd): + # For hotswapping, we need the adapter name to be present in the state dict keys + new_sd = {} + for k, v in sd.items(): + if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): + k = k[: -len(".weight")] + f".{adapter_name}.weight" + elif k.endswith("lora_B.bias"): # lora_bias=True option + k = k[: -len(".bias")] + f".{adapter_name}.bias" + new_sd[k] = v + return new_sd + # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, # we should also delete the `peft_config` associated to the `adapter_name`. try: - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + if hotswap: + state_dict = map_state_dict_for_hotswap(state_dict) + check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + try: + hotswap_adapter_from_state_dict( + model=self, + state_dict=state_dict, + adapter_name=adapter_name, + config=lora_config, + ) + except Exception as e: + logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}") + raise + # the hotswap function raises if there are incompatible keys, so if we reach this point we can set + # it to None + incompatible_keys = None + else: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + if self._prepare_lora_hotswap_kwargs is not None: + # For hotswapping of compiled models or adapters with different ranks. 
+ # If the user called enable_lora_hotswap, we need to ensure it is called: + # - after the first adapter was loaded + # - before the model is compiled and the 2nd adapter is being hotswapped in + # Therefore, it needs to be called here + prepare_model_for_compiled_hotswap( + self, config=lora_config, **self._prepare_lora_hotswap_kwargs + ) + # We only want to call prepare_model_for_compiled_hotswap once + self._prepare_lora_hotswap_kwargs = None + + # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True except Exception as e: # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. if hasattr(self, "peft_config"): @@ -803,3 +825,36 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): # Pop also the corresponding adapter from the config if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + + def enable_lora_hotswap( + self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error" + ) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`, *optional*, defaults to `128`): + The highest rank among all the adapters that will be loaded. + + check_compiled (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing + """ + if getattr(self, "peft_config", {}): + if check_compiled == "error": + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + elif check_compiled == "warn": + logger.warning( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + elif check_compiled != "ignore": + raise ValueError( + f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." 
+ ) + + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2c17b7ca75e3..438faa23e595 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -126,7 +126,7 @@ convert_state_dict_to_kohya, convert_state_dict_to_peft, convert_unet_state_dict_to_peft, - convert_control_lora_state_dict_to_peft, + state_dict_all_zero, ) from .typing_utils import _get_detailed_type, _is_valid_type From 6fff794e59a4d09bb5d6848eaebe297cbcb3c0ec Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Wed, 9 Apr 2025 07:56:40 +0000 Subject: [PATCH 13/19] merged but bug --- src/diffusers/loaders/peft.py | 37 +++++++++++++++++++++++++++++++++ src/diffusers/utils/__init__.py | 1 + 2 files changed, 38 insertions(+) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 9165c46f3c78..0280fc23f7b5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -25,6 +25,7 @@ MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -78,6 +79,33 @@ def _maybe_raise_error_for_ambiguity(config): ) +def _maybe_adjust_config_for_control_lora(config): + """ + """ + + target_modules_before = config["target_modules"] + target_modules = [] + modules_to_save = [] + + for module in target_modules_before: + if module.endswith("weight"): + base_name = ".".join(module.split(".")[:-1]) + modules_to_save.append(base_name) + elif module.endswith("bias"): + base_name = ".".join(module.split(".")[:-1]) + if ".".join([base_name, "weight"]) in target_modules_before: + modules_to_save.append(base_name) + else: + target_modules.append(base_name) + else: + target_modules.append(module) + + config["target_modules"] = list(set(target_modules)) + config["modules_to_save"] = list(set(modules_to_save)) + + return config + + class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For @@ -241,6 +269,13 @@ def load_lora_adapter( "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." 
) + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora/ + is_control_lora = "lora_controlnet" in state_dict + if is_control_lora: + del state_dict["lora_controlnet"] + state_dict = convert_control_lora_state_dict_to_peft(state_dict) + # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: @@ -263,6 +298,8 @@ def load_lora_adapter( lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) _maybe_raise_error_for_ambiguity(lora_config_kwargs) + if is_control_lora: + lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 438faa23e595..777cfec714f6 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,6 +122,7 @@ from .remote_utils import remote_decode from .state_dict_utils import ( convert_all_state_dict_to_peft, + convert_control_lora_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, convert_state_dict_to_peft, From 63bafc88cdf3bd03464e8354d232120106f7f9c4 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Thu, 29 May 2025 14:23:41 +0000 Subject: [PATCH 14/19] change peft.py --- src/diffusers/loaders/peft.py | 54 +++++++---------------------------- 1 file changed, 11 insertions(+), 43 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 0280fc23f7b5..7a970c5c5153 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -25,7 +25,6 @@ MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, - convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -53,9 +52,12 @@ "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, + "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, + "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, + "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, } @@ -79,33 +81,6 @@ def _maybe_raise_error_for_ambiguity(config): ) -def _maybe_adjust_config_for_control_lora(config): - """ - """ - - target_modules_before = config["target_modules"] - target_modules = [] - modules_to_save = [] - - for module in target_modules_before: - if module.endswith("weight"): - base_name = ".".join(module.split(".")[:-1]) - modules_to_save.append(base_name) - elif module.endswith("bias"): - base_name = ".".join(module.split(".")[:-1]) - if ".".join([base_name, "weight"]) in target_modules_before: - modules_to_save.append(base_name) - else: - target_modules.append(base_name) - else: - target_modules.append(module) - - config["target_modules"] = list(set(target_modules)) - config["modules_to_save"] = list(set(modules_to_save)) - - return config - - class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. 
For @@ -256,7 +231,7 @@ def load_lora_adapter( raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -269,13 +244,6 @@ def load_lora_adapter( "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora/ - is_control_lora = "lora_controlnet" in state_dict - if is_control_lora: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) - # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: @@ -294,12 +262,12 @@ def load_lora_adapter( if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + network_alphas = { + k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys + } lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) _maybe_raise_error_for_ambiguity(lora_config_kwargs) - if is_control_lora: - lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -365,7 +333,7 @@ def map_state_dict_for_hotswap(sd): new_sd[k] = v return new_sd - # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, + # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful, # we should also delete the `peft_config` associated to the `adapter_name`. try: if hotswap: @@ -379,7 +347,7 @@ def map_state_dict_for_hotswap(sd): config=lora_config, ) except Exception as e: - logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}") + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}") raise # the hotswap function raises if there are incompatible keys, so if we reach this point we can set # it to None @@ -414,7 +382,7 @@ def map_state_dict_for_hotswap(sd): module.delete_adapter(adapter_name) self.peft_config.pop(adapter_name) - logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") + logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}") raise warn_msg = "" @@ -747,7 +715,7 @@ def _fuse_lora_apply(self, module, adapter_names=None): if self.lora_scale != 1.0: module.scale_layer(self.lora_scale) - # For BC with prevous PEFT versions, we need to check the signature + # For BC with previous PEFT versions, we need to check the signature # of the `merge` method to see if it supports the `adapter_names` argument. 
supported_merge_kwargs = list(inspect.signature(module.merge).parameters) if "adapter_names" in supported_merge_kwargs: From 0a5bd74931930f0c9ccbf38e12165528514c3b57 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 5 Jul 2025 05:12:57 +0000 Subject: [PATCH 15/19] 1 --- src/diffusers/loaders/peft.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 4ade3374d80e..26ce3bc02aed 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -27,6 +27,7 @@ MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -227,6 +228,15 @@ def load_lora_adapter( if "lora_A" not in first_key: state_dict = convert_unet_state_dict_to_peft(state_dict) + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora/ + is_control_lora = "lora_controlnet" in state_dict + if is_control_lora: + state_dict = convert_control_lora_state_dict_to_peft(state_dict) + with open("state_dict1.txt", "w") as f: + for key, val in state_dict.items(): + f.write(f"{key}: {val.shape}\n") + rank = {} for key, val in state_dict.items(): # Cannot figure out rank from lora layers that don't have at least 2 dimensions. @@ -257,6 +267,9 @@ def load_lora_adapter( model_state_dict=self.state_dict(), adapter_name=adapter_name, ) + if is_control_lora: + lora_config.modules_to_save = lora_config.exclude_modules + lora_config.exclude_modules = [] # Date: Sat, 5 Jul 2025 07:52:01 +0000 Subject: [PATCH 16/19] delete state_dict print --- src/diffusers/loaders/peft.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 26ce3bc02aed..8383e22d3065 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -233,9 +233,6 @@ def load_lora_adapter( is_control_lora = "lora_controlnet" in state_dict if is_control_lora: state_dict = convert_control_lora_state_dict_to_peft(state_dict) - with open("state_dict1.txt", "w") as f: - for key, val in state_dict.items(): - f.write(f"{key}: {val.shape}\n") rank = {} for key, val in state_dict.items(): From 23cba1804f65f6f9a48097ceb83aaa88e3b66cb0 Mon Sep 17 00:00:00 2001 From: lavinal712 Date: Sat, 5 Jul 2025 08:18:34 +0000 Subject: [PATCH 17/19] fix alpha --- src/diffusers/loaders/peft.py | 10 ++++++++-- src/diffusers/utils/peft_utils.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 8383e22d3065..b09a7a92e58f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -229,7 +229,8 @@ def load_lora_adapter( state_dict = convert_unet_state_dict_to_peft(state_dict) # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora/ + # https://huggingface.co/stabilityai/control-lora + # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors is_control_lora = "lora_controlnet" in state_dict if is_control_lora: state_dict = convert_control_lora_state_dict_to_peft(state_dict) @@ -264,9 +265,14 @@ def load_lora_adapter( model_state_dict=self.state_dict(), adapter_name=adapter_name, ) + + # Adjust LoRA config for Control LoRA if is_control_lora: + lora_config.lora_alpha = lora_config.r + lora_config.alpha_pattern = lora_config.rank_pattern + lora_config.bias = "all" lora_config.modules_to_save 
= lora_config.exclude_modules - lora_config.exclude_modules = [] + lora_config.exclude_modules = None # Date: Wed, 20 Aug 2025 21:23:04 +0800 Subject: [PATCH 18/19] Create control_lora.py --- .../control_lora/control_lora.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 examples/research_projects/control_lora/control_lora.py diff --git a/examples/research_projects/control_lora/control_lora.py b/examples/research_projects/control_lora/control_lora.py new file mode 100644 index 000000000000..435c9c945b55 --- /dev/null +++ b/examples/research_projects/control_lora/control_lora.py @@ -0,0 +1,53 @@ +import cv2 +import numpy as np +from PIL import Image +import torch + +from diffusers import ( + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, +) +from diffusers import AutoencoderKL +from diffusers.utils import load_image, make_image_grid + +pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" +lora_id = "stabilityai/control-lora" +lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" + +unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda") +controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16) +controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config) + +prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" +negative_prompt = "low quality, bad quality, sketches" + +image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") + +controlnet_conditioning_scale = 1.0 # recommended for good generalization + +vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16) +pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + pipe_id, + unet=unet, + controlnet=controlnet, + vae=vae, + torch_dtype=torch.bfloat16, + safety_checker=None, +).to("cuda") + +image = np.array(image) +image = cv2.Canny(image, 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +image = Image.fromarray(image) + +images = pipe( + prompt, negative_prompt=negative_prompt, image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale, + num_images_per_prompt=4 +).images + +final_image = [image] + images +grid = make_image_grid(final_image, 1, 5) +grid.save("hf-logo_canny.png") From 1e8221ce3916cbcfaadf8002e3ea084bf16f65e3 Mon Sep 17 00:00:00 2001 From: Yuqian Hong Date: Wed, 20 Aug 2025 21:23:22 +0800 Subject: [PATCH 19/19] Add files via upload --- .../research_projects/control_lora/README.md | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 examples/research_projects/control_lora/README.md diff --git a/examples/research_projects/control_lora/README.md b/examples/research_projects/control_lora/README.md new file mode 100644 index 000000000000..49aa848e3e0b --- /dev/null +++ b/examples/research_projects/control_lora/README.md @@ -0,0 +1,41 @@ +# Control-LoRA inference example + +Control-LoRA is introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) by adding low-rank parameter efficient fine tuning to ControlNet. This approach offers a more efficient and compact method to bring model control to a wider variety of consumer GPUs. 
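+
+The core loading path of this example is sketched below. It is a condensed version of `control_lora.py` (shipped alongside this README); device placement and the SDXL pipeline setup are omitted, and note that the `controlnet_config` argument reflects the `load_lora_adapter` extension introduced in this patch series rather than a long-standing diffusers API:
+
+```py
+import torch
+
+from diffusers import ControlNetModel, UNet2DConditionModel
+
+# Base SDXL UNet; the ControlNet is initialized from its configuration.
+unet = UNet2DConditionModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16
+)
+controlnet = ControlNetModel.from_unet(unet)
+
+# Load the Control-LoRA weights (LoRA layers plus the ControlNet-specific modules).
+# `controlnet_config` is the argument added by this patch series (see control_lora.py).
+controlnet.load_lora_adapter(
+    "stabilityai/control-lora",
+    weight_name="control-LoRAs-rank128/control-lora-canny-rank128.safetensors",
+    prefix=None,
+    controlnet_config=controlnet.config,
+)
+```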
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+## Inference on SDXL
+
+[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.
+
+```bash
+python control_lora.py
+```
+
+## Acknowledgements
+
+- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)
+- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)
+- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)
\ No newline at end of file