Commit 32bc07d

factor out stuff from load_lora_adapter().
1 parent 8adc600 commit 32bc07d

2 files changed: +209 -176 lines changed

src/diffusers/loaders/peft.py

Lines changed: 95 additions & 175 deletions
@@ -29,13 +29,18 @@
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import (
+    _create_lora_config,
+    _lora_loading_context,
+    _maybe_warn_for_unhandled_keys,
+    _maybe_warn_if_no_keys_found,
+)
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales

@@ -64,26 +69,6 @@
 }


-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For

@@ -189,7 +174,7 @@ def load_lora_adapter(
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             metadata: TODO
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

         cache_dir = kwargs.pop("cache_dir", None)

@@ -214,7 +199,6 @@ def load_lora_adapter(
             )

         user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
-
         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,

@@ -273,38 +257,8 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-                _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, self.lora_layer_modules)

             # adapter_name
             if adapter_name is None:

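The `_create_lora_config` helper itself belongs to the second changed file, `src/diffusers/utils/peft_utils.py`, which is not shown on this page. The following is a minimal sketch of what it plausibly looks like, reassembled from the config-building logic deleted above; the exact signature (including the trailing `lora_layer_modules` argument seen at the call site) and the final home of the ambiguity check are assumptions, not the committed implementation.

# Hypothetical sketch, not the committed implementation: a plausible shape for the
# factored-out config builder, reassembled from the `use_dora`/`lora_bias` gating and
# `LoraConfig` construction that this hunk deletes.
from peft import LoraConfig

from diffusers.utils import get_peft_kwargs, is_peft_version


def _create_lora_config(state_dict, network_alphas, metadata, rank_pattern_dict, lora_layer_modules=None):
    # `lora_layer_modules` only mirrors the call site above; this sketch does not use it.
    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(
            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict
        )
        # The `_maybe_raise_error_for_ambiguity` check deleted earlier in this diff
        # presumably runs here as well.

    # DoRA-enabled LoRAs need peft>=0.9.0; on older versions an unset flag is simply dropped.
    if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
        if lora_config_kwargs["use_dora"]:
            raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
        lora_config_kwargs.pop("use_dora")

    # `lora_bias` needs peft>=0.14.0; same pattern.
    if "lora_bias" in lora_config_kwargs and is_peft_version("<=", "0.13.2"):
        if lora_config_kwargs["lora_bias"]:
            raise ValueError("You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs.")
        lora_config_kwargs.pop("lora_bias")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
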
@@ -315,132 +269,98 @@ def load_lora_adapter(
             # Now we remove any existing hooks to `_pipeline`.

             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
-                if is_peft_version(">", "0.14.0"):
-                    from peft.utils.hotswap import (
-                        check_hotswap_configs_compatible,
-                        hotswap_adapter_from_state_dict,
-                        prepare_model_for_compiled_hotswap,
-                    )
-                else:
-                    msg = (
-                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
-                        "from source."
-                    )
-                    raise ImportError(msg)
-
-            if hotswap:
-
-                def map_state_dict_for_hotswap(sd):
-                    # For hotswapping, we need the adapter name to be present in the state dict keys
-                    new_sd = {}
-                    for k, v in sd.items():
-                        if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
-                            k = k[: -len(".weight")] + f".{adapter_name}.weight"
-                        elif k.endswith("lora_B.bias"):  # lora_bias=True option
-                            k = k[: -len(".bias")] + f".{adapter_name}.bias"
-                        new_sd[k] = v
-                    return new_sd
-
-            # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
-            # we should also delete the `peft_config` associated to the `adapter_name`.
-            try:
-                if hotswap:
-                    state_dict = map_state_dict_for_hotswap(state_dict)
-                    check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
-                    try:
-                        hotswap_adapter_from_state_dict(
-                            model=self,
-                            state_dict=state_dict,
-                            adapter_name=adapter_name,
-                            config=lora_config,
+            # otherwise loading LoRA weights will lead to an error. So, we use a context manager here
+            # that takes care of enabling and disabling offloading in the pipeline automatically.
+            with _lora_loading_context(_pipeline):
+                peft_kwargs = {}
+                if is_peft_version(">=", "0.13.1"):
+                    peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+                if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
+                    if is_peft_version(">", "0.14.0"):
+                        from peft.utils.hotswap import (
+                            check_hotswap_configs_compatible,
+                            hotswap_adapter_from_state_dict,
+                            prepare_model_for_compiled_hotswap,
                         )
-                    except Exception as e:
-                        logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
-                        raise
-                    # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
-                    # it to None
-                    incompatible_keys = None
-                else:
-                    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-                    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
-
-                if self._prepare_lora_hotswap_kwargs is not None:
-                    # For hotswapping of compiled models or adapters with different ranks.
-                    # If the user called enable_lora_hotswap, we need to ensure it is called:
-                    # - after the first adapter was loaded
-                    # - before the model is compiled and the 2nd adapter is being hotswapped in
-                    # Therefore, it needs to be called here
-                    prepare_model_for_compiled_hotswap(
-                        self, config=lora_config, **self._prepare_lora_hotswap_kwargs
-                    )
-                    # We only want to call prepare_model_for_compiled_hotswap once
-                    self._prepare_lora_hotswap_kwargs = None
-
-                # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
-                if not self._hf_peft_config_loaded:
-                    self._hf_peft_config_loaded = True
-            except Exception as e:
-                # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
-                if hasattr(self, "peft_config"):
-                    for module in self.modules():
-                        if isinstance(module, BaseTunerLayer):
-                            active_adapters = module.active_adapters
-                            for active_adapter in active_adapters:
-                                if adapter_name in active_adapter:
-                                    module.delete_adapter(adapter_name)
-
-                self.peft_config.pop(adapter_name)
-                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
-                raise
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
+                    else:
+                        msg = (
+                            "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+                            "from source."
                         )
+                        raise ImportError(msg)

-            if warn_msg:
-                logger.warning(warn_msg)
+                if hotswap:

-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
+                    def map_state_dict_for_hotswap(sd):
+                        # For hotswapping, we need the adapter name to be present in the state dict keys
+                        new_sd = {}
+                        for k, v in sd.items():
+                            if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
+                                k = k[: -len(".weight")] + f".{adapter_name}.weight"
+                            elif k.endswith("lora_B.bias"):  # lora_bias=True option
+                                k = k[: -len(".bias")] + f".{adapter_name}.bias"
+                            new_sd[k] = v
+                        return new_sd
+
+                # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
+                # we should also delete the `peft_config` associated to the `adapter_name`.
+                try:
+                    if hotswap:
+                        state_dict = map_state_dict_for_hotswap(state_dict)
+                        check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+                        try:
+                            hotswap_adapter_from_state_dict(
+                                model=self,
+                                state_dict=state_dict,
+                                adapter_name=adapter_name,
+                                config=lora_config,
+                            )
+                        except Exception as e:
+                            logger.error(
+                                f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}"
+                            )
+                            raise
+                        # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
+                        # it to None
+                        incompatible_keys = None
+                    else:
+                        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+                    if self._prepare_lora_hotswap_kwargs is not None:
+                        # For hotswapping of compiled models or adapters with different ranks.
+                        # If the user called enable_lora_hotswap, we need to ensure it is called:
+                        # - after the first adapter was loaded
+                        # - before the model is compiled and the 2nd adapter is being hotswapped in
+                        # Therefore, it needs to be called here
+                        prepare_model_for_compiled_hotswap(
+                            self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+                        )
+                        # We only want to call prepare_model_for_compiled_hotswap once
+                        self._prepare_lora_hotswap_kwargs = None
+
+                    # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
+                    if not self._hf_peft_config_loaded:
+                        self._hf_peft_config_loaded = True
+                except Exception as e:
+                    # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
+                    if hasattr(self, "peft_config"):
+                        for module in self.modules():
+                            if isinstance(module, BaseTunerLayer):
+                                active_adapters = module.active_adapters
+                                for active_adapter in active_adapters:
+                                    if adapter_name in active_adapter:
+                                        module.delete_adapter(adapter_name)
+
+                    self.peft_config.pop(adapter_name)
+                    logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
+                    raise
             # Unsafe code />

-        if prefix is not None and not state_dict:
-            logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
-                "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
-                "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
-                "https://github.com/huggingface/diffusers/issues/new"
-            )
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
+
+        _maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=self.__class__.__name__)

         def save_lora_adapter(
             self,
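
The `_lora_loading_context` context manager imported at the top is also defined in `src/diffusers/utils/peft_utils.py` and not shown on this page. A minimal sketch, assuming it simply wraps the offload enable/disable logic deleted from this hunk, could look like the following; the real implementation may handle more cases and use different internals.

# Hypothetical sketch of the context manager, reconstructed from the removed
# `_optionally_disable_offloading(...)` / `enable_*_cpu_offload()` calls above.
from contextlib import contextmanager

from diffusers.loaders.lora_base import _func_optionally_disable_offloading


@contextmanager
def _lora_loading_context(_pipeline):
    # In case the pipeline has already been offloaded to CPU, temporarily remove the
    # hooks; otherwise loading LoRA weights would lead to an error.
    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
    try:
        yield
    finally:
        # Offload back, restoring whichever offloading scheme was active before loading.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()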

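Likewise, the two warning helpers imported at the top, `_maybe_warn_for_unhandled_keys` and `_maybe_warn_if_no_keys_found`, replace the inline warning code deleted from the last hunk. The sketch below reassembles that deleted code; the signatures are inferred from the new call sites and the logger setup is simplified, so treat it as an approximation rather than the committed implementation.

# Hypothetical sketch of the warning helpers, reassembled from the deleted inline code.
import logging

# diffusers uses its own logging utility; stdlib logging stands in here.
logger = logging.getLogger(__name__)


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
    warn_msg = ""
    if incompatible_keys is not None:
        # Check only for unexpected keys.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
            if lora_unexpected_keys:
                warn_msg = (
                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                    f" {', '.join(lora_unexpected_keys)}. "
                )

        # Filter missing keys specific to the current adapter.
        missing_keys = getattr(incompatible_keys, "missing_keys", None)
        if missing_keys:
            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
            if lora_missing_keys:
                warn_msg += (
                    f"Loading adapter weights from state_dict led to missing keys in the model:"
                    f" {', '.join(lora_missing_keys)}."
                )

    if warn_msg:
        logger.warning(warn_msg)


def _maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name):
    if prefix is not None and not state_dict:
        logger.warning(
            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
            "This is safe to ignore if LoRA state dict didn't originally have any "
            f"{model_class_name} related params. You can also try specifying `prefix=None` "
            "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
            "https://github.com/huggingface/diffusers/issues/new"
        )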