[LoRA] Implement hot-swapping of LoRA #9453
src/diffusers/loaders/peft.py
@@ -138,7 +138,9 @@ def _optionally_disable_offloading(cls, _pipeline):
         """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)

-    def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+    def load_lora_adapter(
+        self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
+    ):
         r"""
         Loads a LoRA adapter into the underlying model.

@@ -182,6 +184,28 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place.
+                This means that, instead of loading an additional adapter, this will take the existing adapter weights
+                and replace them with the weights of the new adapter. This can be faster and more memory efficient.
+                However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
+                loading the new adapter does not require recompilation of the model.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+                model = ...  # load diffusers model with first LoRA adapter
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
+                model = torch.compile(model)
+                model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+                ```
+
+                There are some limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

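To make the docstring example above more concrete, here is a hedged end-to-end sketch (not part of the PR itself): the checkpoint and LoRA repository ids are placeholders, `FluxTransformer2DModel` stands in for any diffusers model that exposes `load_lora_adapter`, and `target_rank=64` is an arbitrary upper bound on the adapter ranks.

```py
import torch
from diffusers import FluxTransformer2DModel
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

# Load the base model and the first LoRA adapter (placeholder repo ids).
model = FluxTransformer2DModel.from_pretrained(
    "some/flux-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.load_lora_adapter("user/lora-one", adapter_name="default_0")

# Pad all LoRA layers to the largest rank we plan to load, *before* compiling,
# so that swapping in a different adapter later does not change tensor shapes.
prepare_model_for_compiled_hotswap(model, target_rank=64)
model = torch.compile(model)

# Replace the weights of "default_0" in place; no recompilation is triggered.
model.load_lora_adapter("user/lora-two", adapter_name="default_0", hotswap=True)
```
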
@@ -235,10 +259,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}

         if len(state_dict) > 0:
-            if adapter_name in getattr(self, "peft_config", {}):
+            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
                 )
+            elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+                raise ValueError(
+                    f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
+                    "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
+                )

             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))

@@ -296,11 +325,47 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

+            if hotswap:
+                try:
+                    from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+                except ImportError as exc:
+                    msg = (
+                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+                        "from source."
+                    )
+                    raise ImportError(msg) from exc

Review thread on the `try`/`except ImportError` block above, referencing the existing helper `def is_peft_version(operation: str, version: str):`

Comment: I did this on purpose, as it allows to test the feature by installing PEFT from main. Otherwise, we'd have to wait for the next PEFT release. Normally, I'd also avoid `try: import ...` for the side effect, but at this point, PEFT is already imported, so that's not a factor. If you still want me to change this, LMK.

Reply: Can't we do `is_peft_version(">", "0.14.0")`?

Reply: Ah yes, that should work, I'll fix.
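For illustration, a minimal sketch of what the suggested version gate could look like, assuming diffusers' existing `is_peft_version` helper; the `_import_peft_hotswap` wrapper is a hypothetical name used only to keep the snippet self-contained, not code from the PR.

```py
from diffusers.utils import is_peft_version


def _import_peft_hotswap():
    # Gate on the PEFT version instead of catching ImportError; versions newer than
    # 0.14.0 (including installs from source) are expected to ship peft.utils.hotswap.
    if not is_peft_version(">", "0.14.0"):
        raise ImportError(
            "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it from source."
        )
    from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict

    return check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
```
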
src/diffusers/loaders/unet.py
@@ -281,7 +281,14 @@ def _process_custom_diffusion(self, state_dict):
         return attn_processors

     def _process_lora(
-        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+        self,
+        state_dict,
+        unet_identifier_key,
+        network_alphas,
+        adapter_name,
+        _pipeline,
+        low_cpu_mem_usage,
+        hotswap: bool = False,
     ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy

@@ -294,6 +301,7 @@ def _process_lora(
             raise ValueError("PEFT backend is required for this method.")

         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

         keys = list(state_dict.keys())

@@ -313,10 +321,15 @@ def _process_lora(
             state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
-            if adapter_name in getattr(self, "peft_config", {}):
+            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
                 )
+            elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+                raise ValueError(
+                    f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
+                    "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
+                )

             state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

@@ -364,8 +377,59 @@ def _process_lora(
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+            if hotswap:
+                try:
+                    from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+                except ImportError as exc:
+                    msg = (
+                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+                        "from source."
+                    )
+                    raise ImportError(msg) from exc

Review comment on the same `try`/`except ImportError` block in this file, again referencing `def is_peft_version(operation: str, version: str)`: Same comment as above.
Review comment (on the error handling): Can we raise the exception properly instead of logging an error?

Review comment: Also, are we testing if this error is raised?

Reply: Note that I just copied the pattern from here:
diffusers/src/diffusers/loaders/peft.py, lines 301 to 316 in 97abdd2:
    try:
        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
    except Exception as e:
        # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
        if hasattr(self, "peft_config"):
            for module in self.modules():
                if isinstance(module, BaseTunerLayer):
                    active_adapters = module.active_adapters
                    for active_adapter in active_adapters:
                        if adapter_name in active_adapter:
                            module.delete_adapter(adapter_name)
            self.peft_config.pop(adapter_name)
        logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
        raise
So this is just for consistency.
Reply: Good point. Let me provide some reasoning as to why I added the error (referenced) that way: PEFT already raises a nice error when the code reaches that part in PEFT, so it didn't make sense to craft another error message on top of it; instead, we decided to just propagate it to the users coming via diffusers. I think that is okay to do here.

Regardless, @BenjaminBossan, are we testing for the error that should be raised here in case hotswap fails? Or no need?
Reply: I think we don't need a test that is specific to hotswap failing, as it doesn't really matter why loading the adapter fails. If a test is added, it should probably be something similar to the test that was added when the change was introduced in peft.py: https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552. Not sure if it's possible, but maybe that test can be adjusted to trigger this code path?
Reply: Yeah, I can look into it in a future PR. Possible to add a note?
Reply: Added a TODO comment.
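For reference, a rough pytest-style sketch of a test along these lines, covering the new validation errors rather than the failure/cleanup path itself; `model` and `lora_state_dict` are hypothetical fixtures (a LoRA-capable diffusers model and a matching LoRA state dict), not existing fixtures in the diffusers test suite.

```py
import pytest


def test_hotswap_requires_existing_adapter(model, lora_state_dict):
    # hotswap=True with an adapter name that was never loaded has nothing to swap out,
    # so load_lora_adapter should raise the new ValueError.
    with pytest.raises(ValueError, match="no existing adapter by that name"):
        model.load_lora_adapter(lora_state_dict, adapter_name="default_0", hotswap=True)


def test_duplicate_adapter_name_without_hotswap(model, lora_state_dict):
    # Re-using an adapter name without hotswap=True keeps the old behaviour of refusing
    # to overwrite the existing adapter.
    model.load_lora_adapter(lora_state_dict, adapter_name="default_0")
    with pytest.raises(ValueError, match="already in use"):
        model.load_lora_adapter(lora_state_dict, adapter_name="default_0")
```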