4 changes: 2 additions & 2 deletions src/diffusers/loaders/lora_base.py
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(

     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(

         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            network_alphas = {k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys}

         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

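Quick semantic note on the change above (not part of the diff): with the keys already filtered by `startswith`, slicing, str.replace, and str.removeprefix all strip the same leading prefix, but replace would also drop the prefix substring if it happened to recur later in a key, whereas removeprefix only ever touches the start. A minimal sketch with made-up key names:

    prefix = "text_encoder"
    keys = [
        "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight",
        "text_encoder.text_encoder.alpha",  # contrived: the prefix substring repeats inside the key
    ]

    for k in keys:
        if not k.startswith(f"{prefix}."):
            continue
        sliced = k[len(f"{prefix}.") :]         # old form used for state_dict
        replaced = k.replace(f"{prefix}.", "")  # old form used for network_alphas
        removed = k.removeprefix(prefix + ".")  # new form
        assert sliced == removed
        print(replaced == removed)  # True for the first key, False for the contrived second one

For real LoRA state dicts the prefix does not normally recur, so the rewrite is behavior-preserving; removeprefix just states the intent more directly than manual slicing.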
4 changes: 2 additions & 2 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2103,7 +2103,7 @@ def _load_norm_into_transformer(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)

         # Find invalid keys
         transformer_state_dict = transformer.state_dict()
@@ -2425,7 +2425,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
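Both transformer hooks above use the same in-place rename pattern: snapshot the keys, pop each matching entry, and reinsert it under the stripped name. A standalone sketch (the keys and the "transformer" prefix are illustrative, standing in for cls.transformer_name):

    prefix = "transformer"  # stands in for cls.transformer_name
    state_dict = {
        "transformer.single_transformer_blocks.0.norm.linear.weight": 0.0,  # dummy tensor value
        "text_encoder.some.other.key": 1.0,  # untouched: first path segment differs
    }

    # Iterate over a snapshot of the keys because the dict is mutated while looping.
    for key in list(state_dict.keys()):
        if key.split(".")[0] == prefix:
            state_dict[key.removeprefix(prefix + ".")] = state_dict.pop(key)

    print(sorted(state_dict))
    # ['single_transformer_blocks.0.norm.linear.weight', 'text_encoder.some.other.key']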
6 changes: 4 additions & 2 deletions src/diffusers/loaders/peft.py
@@ -230,7 +230,7 @@ def load_lora_adapter(
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(prefix + "."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -261,7 +261,9 @@ def load_lora_adapter(

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(prefix + "."): v for k, v in network_alphas.items() if k in alpha_keys
+                }

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
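One portability note: str.removeprefix only exists on Python 3.9 and later (PEP 616), so using it here assumes the library's minimum supported interpreter is at least 3.9. Where that cannot be assumed, the same behaviour can be spelled out with a small helper (hypothetical, not part of this PR):

    def _strip_prefix(key: str, prefix: str) -> str:
        # Mirrors key.removeprefix(prefix + ".") for interpreters older than 3.9.
        token = prefix + "."
        return key[len(token):] if key.startswith(token) else key

    assert _strip_prefix("unet.down_blocks.0.lora_A.weight", "unet") == "down_blocks.0.lora_A.weight"
    assert _strip_prefix("no_prefix_here", "unet") == "no_prefix_here"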