ENH: Add support for LoRA hotswapping #41297
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| import importlib | ||
| import inspect | ||
| import re | ||
| from typing import Any, Optional, Union | ||
| from typing import Any, Literal, Optional, Union | ||
|
|
||
| from packaging import version | ||
|
|
||
|
|
@@ -83,6 +83,7 @@ class PeftAdapterMixin: | |
| """ | ||
|
|
||
| _hf_peft_config_loaded = False | ||
| _prepare_peft_hotswap_kwargs: Optional[dict] = None | ||
|
|
||
| def load_adapter( | ||
| self, | ||
|
|
@@ -98,6 +99,7 @@ def load_adapter( | |
| adapter_state_dict: Optional[dict[str, "torch.Tensor"]] = None, | ||
| low_cpu_mem_usage: bool = False, | ||
| is_trainable: bool = False, | ||
| hotswap: bool = False, | ||
| adapter_kwargs: Optional[dict[str, Any]] = None, | ||
| ) -> None: | ||
| """ | ||
|
|
@@ -153,12 +155,52 @@ def load_adapter( | |
| is_trainable (`bool`, *optional*, defaults to `False`): | ||
| Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be | ||
| used for inference. | ||
| hotswap (`bool`, *optional*, defaults to `False`): | ||
| Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means | ||
| that, instead of loading an additional adapter, this will take the existing adapter weights and replace | ||
| them with the weights of the new adapter. This can be faster and more memory efficient. However, the | ||
| main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new | ||
| adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name` | ||
| should be the name of an already loaded adapter. | ||
|
|
||
| If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need | ||
| to call an additional method before loading the adapter: | ||
|
|
||
| ```py | ||
| model = AutoModel.from_pretrained(...) | ||
| max_rank = ... # the highest rank among all LoRAs that you want to load | ||
| # call *before* compiling and loading the LoRA adapter | ||
| model.enable_peft_hotswap(target_rank=max_rank) | ||
| model.load_adapter(file_name_1, adapter_name="default") | ||
| # optionally compile the model now | ||
| model = torch.compile(model, ...) | ||
| output_1 = model(...) | ||
| # now you can hotswap the 2nd adapter, use the same name as for the 1st | ||
| model.load_adapter(file_name_2, adapter_name="default", hotswap=True) | ||
| output_2 = model(...) | ||
| ``` | ||
|
|
||
| Note that hotswapping comes with a couple of limitations documented here: | ||
| https://huggingface.co/docs/peft/main/en/package_reference/hotswap | ||
| adapter_kwargs (`dict[str, Any]`, *optional*): | ||
| Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and | ||
| `find_adapter_config_file` method. | ||
| """ | ||
| from peft import PeftType | ||
|
|
||
|
||
| check_peft_version(min_version=MIN_PEFT_VERSION) | ||
|
|
||
| if hotswap: | ||
| min_version_hotswap = "0.15.0" | ||
| if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap): | ||
| raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.") | ||
| if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config): | ||
| raise ValueError( | ||
| "To hotswap an adapter, there must already be an existing adapter with the same adapter name." | ||
| ) | ||
| if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()): | ||
| raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.") | ||
|
|
||
|
Comment on lines +218 to +220

Review comment: Just concerned about that, maybe

Reply: This would only crash if the user tries to load non-LoRA and if they:

In either case, the user intent is to use hotswapping. Therefore, I think that raising is the better choice, otherwise the user would think they used hotswapping successfully when it's not actually being used.
||
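To illustrate the behaviour discussed in this thread, here is a minimal usage sketch (the model class and checkpoint names are placeholders, not taken from this PR):

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("base-model")  # placeholder checkpoint

# An adapter with the target name must already be loaded before hotswapping;
# calling load_adapter(..., hotswap=True) first would raise
# "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
model.load_adapter("lora-checkpoint-1", adapter_name="default")

# The second LoRA is then swapped in-place under the same adapter name.
model.load_adapter("lora-checkpoint-2", adapter_name="default", hotswap=True)

# Passing hotswap=True for a non-LoRA adapter raises instead of silently falling
# back to a regular load, which is the behaviour defended in the reply above.
```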
| # peft only supports low_cpu_mem_usage starting from v0.13.0 | ||
| peft_load_kwargs = {} | ||
| key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None | ||
|
|
@@ -181,8 +223,12 @@ def load_adapter( | |
| from peft import PeftConfig, inject_adapter_in_model, load_peft_weights | ||
| from peft.utils import set_peft_model_state_dict | ||
|
|
||
| if self._hf_peft_config_loaded and adapter_name in self.peft_config: | ||
| if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config): | ||
| raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") | ||
| elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)): | ||
| raise ValueError( | ||
| "To hotswap an adapter, there must already be an existing adapter with the same adapter name." | ||
| ) | ||
|
|
||
| if peft_model_id is None and (adapter_state_dict is None and peft_config is None): | ||
| raise ValueError( | ||
|
|
@@ -231,8 +277,12 @@ def load_adapter( | |
| ) | ||
| peft_config.inference_mode = not is_trainable | ||
|
|
||
| # Create and add fresh new adapters into the model. | ||
| inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) | ||
| if peft_config.peft_type != PeftType.LORA: | ||
| raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.") | ||
|
|
||
| if not hotswap: | ||
| # Create and add fresh new adapters into the model, unless the weights are hotswapped | ||
| inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) | ||
|
|
||
| if not self._hf_peft_config_loaded: | ||
| self._hf_peft_config_loaded = True | ||
|
|
@@ -255,12 +305,47 @@ def load_adapter( | |
| # Early exit of the loop | ||
| if n_replace > 0: | ||
| break | ||
|
|
||
| # For hotswapping, we need the adapter name to be present in the state dict keys | ||
| if hotswap: | ||
| if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"): | ||
| new_key = new_key[: -len(".weight")] + f".{adapter_name}.weight" | ||
| elif key.endswith("lora_B.bias"): # lora_bias=True option | ||
| new_key = new_key[: -len(".bias")] + f".{adapter_name}.bias" | ||
| processed_adapter_state_dict[new_key] = value | ||
|
|
||
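The key renaming above can be reproduced in isolation with a small stand-alone sketch (the module path in the example key is made up):

```py
def add_adapter_name(key: str, adapter_name: str) -> str:
    # Mirrors the hotswap remapping in load_adapter:
    #   "...lora_A.weight" / "...lora_B.weight" -> "...lora_A.<name>.weight" / "...lora_B.<name>.weight"
    #   "...lora_B.bias" (lora_bias=True)       -> "...lora_B.<name>.bias"
    if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
        return key[: -len(".weight")] + f".{adapter_name}.weight"
    if key.endswith("lora_B.bias"):
        return key[: -len(".bias")] + f".{adapter_name}.bias"
    return key

print(add_adapter_name("model.layers.0.self_attn.q_proj.lora_A.weight", "default"))
# -> model.layers.0.self_attn.q_proj.lora_A.default.weight
```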
| # Load state dict | ||
| incompatible_keys = set_peft_model_state_dict( | ||
| self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs | ||
| ) | ||
| if not hotswap: | ||
| incompatible_keys = set_peft_model_state_dict( | ||
| self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs | ||
| ) | ||
|
|
||
| if self._prepare_peft_hotswap_kwargs is not None: | ||
| # For hotswapping of compiled models or adapters with different ranks. | ||
| # If the user called enable_peft_hotswap, we need to ensure it is called: | ||
| # - after the first adapter was loaded | ||
| # - before the model is compiled and the 2nd adapter is being hotswapped in | ||
| # Therefore, it needs to be called here | ||
| from peft.utils.hotswap import prepare_model_for_compiled_hotswap | ||
|
|
||
| prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs) | ||
| # We only want to call prepare_model_for_compiled_hotswap once | ||
| self._prepare_peft_hotswap_kwargs = None | ||
|
Comment on lines +340 to +350

Review comment: Should this be under the condition

Reply: It is, see line 335.
||
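The ordering these comments describe can be summed up in a short sketch (model and checkpoint names are placeholders): `enable_peft_hotswap` only records its arguments, and the first `load_adapter` call consumes them before any compilation happens.

```py
import torch

model.enable_peft_hotswap(target_rank=64)   # just stores kwargs in _prepare_peft_hotswap_kwargs
assert model._prepare_peft_hotswap_kwargs is not None

model.load_adapter("lora-checkpoint-1", adapter_name="default")
# prepare_model_for_compiled_hotswap ran inside this first load and the stored kwargs
# were cleared, so the preparation step is not repeated for later hotswaps.
assert model._prepare_peft_hotswap_kwargs is None

model = torch.compile(model)                 # safe to compile now
model.load_adapter("lora-checkpoint-2", adapter_name="default", hotswap=True)
```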
| else: | ||
| from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict | ||
|
|
||
| check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config) | ||
| try: | ||
| hotswap_adapter_from_state_dict( | ||
| model=self, | ||
| state_dict=processed_adapter_state_dict, | ||
| adapter_name=adapter_name, | ||
| config=peft_config, | ||
| ) | ||
| except Exception as e: | ||
| logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}") | ||
| raise | ||
| incompatible_keys = None | ||
|
|
||
| if incompatible_keys is not None: | ||
| err_msg = "" | ||
|
|
@@ -302,6 +387,46 @@ def load_adapter( | |
| offload_index=offload_index, | ||
| ) | ||
|
|
||
| def enable_peft_hotswap( | ||
| self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error" | ||
| ) -> None: | ||
| """Enables the possibility to hotswap PEFT adapters with different ranks, or, if the model is compiled, without | ||
| triggering recompilation. | ||
|
|
||
| Right now, hotswapping is only supported for LoRA. | ||
|
|
||
| Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of | ||
| the loaded adapters differ. If the ranks are all identical and the model is not compiled, hotswapping works | ||
| without calling this method first. | ||
|
|
||
| Args: | ||
| target_rank (`int`, *optional*, defaults to `128`): | ||
| The highest rank among all the adapters that will be loaded. | ||
| check_compiled (`str`, *optional*, defaults to `"error"`): | ||
| How to handle the case when the model is already compiled, which should generally be avoided. The | ||
| options are: | ||
| - "error" (default): raise an error | ||
| - "warn": issue a warning | ||
| - "ignore": do nothing | ||
| """ | ||
| min_version_hotswap = "0.15.0" | ||
| if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap): | ||
| raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.") | ||
|
|
||
| if getattr(self, "peft_config", {}): | ||
| if check_compiled == "error": | ||
| raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.") | ||
| elif check_compiled == "warn": | ||
| logger.warning( | ||
| "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation." | ||
| ) | ||
| elif check_compiled != "ignore": | ||
| raise ValueError( | ||
| f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." | ||
| ) | ||
|
|
||
| self._prepare_peft_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} | ||
|
|
||
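As a brief usage note for the new method (a sketch; rank values and checkpoint names are illustrative): call it before the first adapter is loaded, with `target_rank` set to the largest rank among the adapters you plan to swap.

```py
# Before any adapter is loaded (and before torch.compile):
model.enable_peft_hotswap(target_rank=64, check_compiled="error")

model.load_adapter("lora-rank-16", adapter_name="default")
model.load_adapter("lora-rank-64", adapter_name="default", hotswap=True)

# If adapters were already loaded when enable_peft_hotswap is called, the method
# raises with check_compiled="error", warns with "warn", and does nothing with "ignore".
```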
| def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None: | ||
| r""" | ||
| If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT | ||
|
|
||