Commit bb0d4a1

finish./

1 parent 8c988f4 commit bb0d4a1

File tree: 1 file changed (+15, -4 lines)


src/diffusers/loaders/peft.py

Lines changed: 15 additions & 4 deletions
@@ -54,7 +54,6 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
 _NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
-_FULL_NAME_PREFIX_FOR_PEFT = "FULL-NAME"


 def _maybe_adjust_config(config):
@@ -189,7 +188,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
-        from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+
+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
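
The try/except import above is a simple version gate: on PEFT releases that predate FULLY_QUALIFIED_PATTERN_KEY_PREFIX, the name is bound to None, and the call sites later in the diff branch on that sentinel. A minimal self-contained sketch of the same pattern, with a hypothetical rank_key() helper standing in for the inline branches in the diff:

# Version-gating sketch: bind the constant to None on older PEFT releases
# (or when PEFT is absent), then branch on it wherever the newer behavior
# is wanted.
try:
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:  # this PEFT version does not expose the constant
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None


def rank_key(key: str) -> str:
    """Hypothetical helper: prefix a state-dict key so newer PEFT matches it
    as a full module name; return it unchanged on older PEFT."""
    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
        return f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"
    return key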
@@ -255,14 +258,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    # Support for treating layer patterns as full layer names
+                    # was added later in PEFT, so we handle both cases here.
+                    # TODO: once we bump the minimal PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            # lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
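
Why the prefix (and the _maybe_adjust_config() fallback) exists: plain rank keys in a LoRA config are treated as patterns matched against module names, so a bare suffix like "to_k" hits every attention block at once. Prefixing each key with FULLY_QUALIFIED_PATTERN_KEY_PREFIX asks newer PEFT to match the key as an exact module name instead. A small illustration, assuming pattern-style matching on plain keys (the dict below is illustrative, not diffusers code):

# Two LoRA-B weights whose modules share the "to_k" suffix but need
# different ranks; exact, fully qualified keys keep them apart.
ranks_by_module = {
    "blocks.0.attn.to_k": 4,
    "blocks.1.attn.to_k": 8,
}
# Collapsed to pattern-style suffix keys, both entries would compete for a
# single "to_k" pattern and one rank would silently win. On older PEFT,
# which only supports pattern keys, _maybe_adjust_config() rewrites the
# config to work around such collisions; on newer PEFT the prefixed keys
# make that workaround unnecessary.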
