
Commit 984b8c9

updates.
1 parent 4d307cc commit 984b8c9

File tree

2 files changed: +16 −88 lines changed


src/diffusers/loaders/lora_pipeline.py

Lines changed: 9 additions & 87 deletions
@@ -21,7 +21,6 @@
     USE_PEFT_BACKEND,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
@@ -1845,92 +1844,15 @@ def load_lora_into_transformer(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        keys = list(state_dict.keys())
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            if network_alphas is not None and len(network_alphas) >= 1:
-                prefix = cls.transformer_name
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
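
For context, here is a minimal usage sketch of the call path this hunk simplifies. The pipeline class, checkpoint id, and LoRA file name are hypothetical placeholders; only `load_lora_weights`, `load_lora_into_transformer`, and the new delegation to `transformer.load_lora_adapter` come from the diff above.

import torch
from diffusers import DiffusionPipeline

# Hypothetical model and LoRA ids, for illustration only.
pipe = DiffusionPipeline.from_pretrained("org/some-transformer-pipeline", torch_dtype=torch.float16)

# load_lora_weights() resolves the checkpoint into a state dict (plus optional
# network alphas) and calls load_lora_into_transformer(), which after this
# commit simply forwards everything to transformer.load_lora_adapter().
pipe.load_lora_weights("org/some-lora-repo", weight_name="pytorch_lora_weights.safetensors", adapter_name="style")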

src/diffusers/loaders/peft.py

Lines changed: 7 additions & 1 deletion
@@ -116,6 +116,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         adapter_name = kwargs.pop("adapter_name", None)
+        network_alphas = kwargs.pop("network_alphas", None)
         _pipeline = kwargs.pop("_pipeline", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False
@@ -166,7 +167,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if network_alphas is not None and len(network_alphas) >= 1:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                     raise ValueError(
@@ -187,6 +192,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
