diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 297000560962..8b417341ca13 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -34,7 +34,6 @@ delete_adapter_layers, deprecate, get_adapter_name, - get_peft_kwargs, is_accelerate_available, is_peft_available, is_peft_version, @@ -46,14 +45,13 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.peft_utils import _create_lora_config from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): from transformers import PreTrainedModel - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules - if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer @@ -352,8 +350,6 @@ def _load_lora_into_text_encoder( ) peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - from peft import LoraConfig - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # their prefixes. @@ -377,60 +373,25 @@ def _load_lora_into_text_encoder( # convert state dict state_dict = convert_state_dict_to_peft(state_dict) - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] + for name, _ in text_encoder.named_modules(): + if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")): + rank_key = f"{name}.lora_B.weight" + if rank_key in state_dict: + rank[rank_key] = state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} - if metadata is not None: - lora_config_kwargs = metadata - else: - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
-                    )
-            else:
-                if is_peft_version("<=", "0.13.2"):
-                    lora_config_kwargs.pop("lora_bias")
-
-        try:
-            lora_config = LoraConfig(**lora_config_kwargs)
-        except TypeError as e:
-            raise TypeError("`LoraConfig` class could not be instantiated.") from e
+        # create `LoraConfig`
+        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
 
         # adapter_name
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)
 
     if prefix is not None and not state_dict:
+        model_class_name = text_encoder.__class__.__name__
         logger.warning(
-            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
             "This is safe to ignore if LoRA state dict didn't originally have any "
-            f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+            f"{model_class_name} related params. You can also try specifying `prefix=None` "
             "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
             "https://github.com/huggingface/diffusers/issues/new"
         )
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 6bb6e369368d..3cc3296ce43d 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -29,13 +29,13 @@
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
 
@@ -64,26 +64,6 @@
 }
 
 
-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
    A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -191,7 +171,7 @@ def load_lora_adapter(
                 LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                 initialize `LoraConfig`.
""" - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft import inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer cache_dir = kwargs.pop("cache_dir", None) @@ -216,7 +196,6 @@ def load_lora_adapter( ) user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, @@ -275,38 +254,8 @@ def load_lora_adapter( k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys } - if metadata is not None: - lora_config_kwargs = metadata - else: - lora_config_kwargs = get_peft_kwargs( - rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict - ) - _maybe_raise_error_for_ambiguity(lora_config_kwargs) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - try: - lora_config = LoraConfig(**lora_config_kwargs) - except TypeError as e: - raise TypeError("`LoraConfig` class could not be instantiated.") from e + # create LoraConfig + lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank) # adapter_name if adapter_name is None: @@ -317,9 +266,8 @@ def load_lora_adapter( # Now we remove any existing hooks to `_pipeline`. # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error + # otherwise loading LoRA weights will lead to an error. is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) - peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage @@ -403,30 +351,7 @@ def map_state_dict_for_hotswap(sd): logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}") raise - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) + _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name) # Offload back. 
             if is_model_cpu_offload:
@@ -436,10 +361,11 @@ def map_state_dict_for_hotswap(sd):
             # Unsafe code />
 
         if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
             logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                 "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                 "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                 "https://github.com/huggingface/diffusers/issues/new"
             )
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 58fe51b6f4dd..1b8a5f6f0020 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -21,9 +21,12 @@
 
 from packaging import version
 
-from .import_utils import is_peft_available, is_torch_available
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available
 
 
+logger = logging.get_logger(__name__)
+
 if is_torch_available():
     import torch
 
@@ -288,3 +291,91 @@ def check_peft_version(min_version: str) -> None:
             f"The version of PEFT you are using is not compatible, please use a version that is greater"
             f" than {min_version}"
         )
+
+
+def _create_lora_config(
+    state_dict,
+    network_alphas,
+    metadata,
+    rank_pattern_dict,
+    is_unet: bool = True,
+):
+    from peft import LoraConfig
+
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(
+            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+        )
+
+    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+    # Version checks for DoRA and `lora_bias`.
+    if "use_dora" in lora_config_kwargs:
+        if lora_config_kwargs["use_dora"]:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+        elif is_peft_version("<", "0.9.0"):
+            # Older PEFT versions don't know the `use_dora` kwarg; drop it when DoRA is disabled.
+            lora_config_kwargs.pop("use_dora")
+
+    if "lora_bias" in lora_config_kwargs:
+        if lora_config_kwargs["lora_bias"]:
+            if is_peft_version("<=", "0.13.2"):
+                raise ValueError("`lora_bias` requires PEFT >= 0.14.0. Please upgrade.")
+        elif is_peft_version("<=", "0.13.2"):
+            # Likewise, drop `lora_bias` for PEFT versions that predate it.
+            lora_config_kwargs.pop("lora_bias")
+
+    try:
+        return LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+
+    for key in list(rank_pattern.keys()):
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+
+        if exact_matches and substring_matches:
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+    warn_msg = ""
+    if incompatible_keys is not None:
+        # Check only for unexpected keys.
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 3742d395e776..93dc4a2c37e3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1794,7 +1794,7 @@ def test_missing_keys_warning(self): missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] - logger = logging.get_logger("diffusers.loaders.peft") + logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) @@ -1829,7 +1829,7 @@ def test_unexpected_keys_warning(self): unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) - logger = logging.get_logger("diffusers.loaders.peft") + logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) @@ -2006,9 +2006,6 @@ def test_lora_B_bias(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") - logger.setLevel(logging.INFO) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] denoiser_lora_config.lora_bias = False
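For reference, here is a minimal, hypothetical sketch of how the relocated `_create_lora_config` helper could be exercised on its own after this refactor. It is not part of the diff; the toy state-dict keys, tensor shapes, and rank values are made up for illustration, and the call assumes the signature added to `peft_utils.py` above.

```python
# Hypothetical usage sketch, not part of the diff above.
# The keys, shapes, and rank values are invented purely for illustration.
import torch

from diffusers.utils.peft_utils import _create_lora_config

# Toy PEFT-format LoRA weights for a single q_proj module with rank 4.
state_dict = {
    "layers.0.attn.q_proj.lora_A.weight": torch.zeros(4, 32),
    "layers.0.attn.q_proj.lora_B.weight": torch.zeros(32, 4),
}

# Mirrors how the loaders above collect ranks: one entry per `lora_B` key,
# taken from `state_dict[...].shape[1]`.
rank = {"layers.0.attn.q_proj.lora_B.weight": 4}

lora_config = _create_lora_config(
    state_dict,
    network_alphas=None,  # no network alphas in this toy example
    metadata=None,  # no serialized metadata, so kwargs are inferred from the state dict
    rank_pattern_dict=rank,
    is_unet=False,
)
print(lora_config.r, sorted(lora_config.target_modules))
```

With no metadata supplied, `get_peft_kwargs` derives `r`, `lora_alpha`, and `target_modules` from the state-dict keys, which is why the toy keys above are enough for the helper to build a valid `LoraConfig`.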