46 changes: 46 additions & 0 deletions docs/source/en/peft.md
@@ -151,3 +151,49 @@ model.enable_adapters()
# disable all adapters
model.disable_adapters()
```

## Hotswapping adapters

A common use case when serving multiple adapters is to load one adapter, generate output, load another adapter, generate more output, and so on. This can be inefficient because each time a new adapter is loaded, new memory is reserved; moreover, if the model is compiled with `torch.compile`, it needs to be recompiled every time a new adapter is used. When switching adapters frequently, the compilation time may never be amortized.

To better support this common workflow, you can "hotswap" a LoRA adapter to avoid accumulating memory and, in some cases, recompilation. Hotswapping requires an adapter to already be loaded; the new adapter weights are then swapped in place into the existing adapter. Note that hotswapping is currently only supported for LoRA, not for other PEFT methods.

Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to pass the name of the existing adapter to be swapped (`"default"` is the default adapter name).

```python
model = AutoModel.from_pretrained(...)
# load adapter 1 as normal
model.load_adapter(file_name_adapter_1)
# generate outputs with adapter 1
...
# now hotswap the 2nd adapter
model.load_adapter(file_name_adapter_2, hotswap=True, adapter_name="default")
# generate outputs with adapter 2
```

For compiled models, it is often necessary to call [`~integrations.peft.PeftAdapterMixin.enable_peft_hotswap`] to avoid recompilation. Call this method _before_ loading the first adapter, while `torch.compile` should be called _after_ loading the first adapter.

```python
model = AutoModel.from_pretrained(...)
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
model.enable_peft_hotswap(target_rank=max_rank)
model.load_adapter(file_name_1, adapter_name="default")
# optionally compile the model now
model = torch.compile(model, ...)
output_1 = model(...)
# now you can hotswap the 2nd adapter, use the same name as for the 1st
model.load_adapter(file_name_2, adapter_name="default")
output_2 = model(...)
```

The `target_rank=max_rank` argument sets the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. If in doubt, use a higher value. By default, this value is 128.
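
For instance, with one rank-8 and one rank-16 adapter, a minimal sketch could look like this (the adapter paths below are placeholders):

```python
model = AutoModel.from_pretrained(...)
# the largest rank among the adapters below is 16
model.enable_peft_hotswap(target_rank=16)
model.load_adapter("path/to/lora_rank8", adapter_name="default")   # rank 8
model = torch.compile(model)
output_1 = model(...)
# the rank-16 adapter is hotswapped in without recompilation
model.load_adapter("path/to/lora_rank16", adapter_name="default")  # rank 16
output_2 = model(...)
```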

By default, hotswapping is disabled and requires passing `hotswap=True` to `load_adapter`. However, if you called `enable_peft_hotswap` first, hotswapping is enabled by default; pass `hotswap=False` to opt out.
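
To make the two cases concrete, here is a minimal sketch (the file and adapter names are placeholders in the style of the examples above):

```python
# without enable_peft_hotswap: opt in explicitly
model.load_adapter(file_name_adapter_2, hotswap=True, adapter_name="default")

# after enable_peft_hotswap: reusing an existing adapter name hotswaps automatically;
# pass hotswap=False to load the weights under a new name as a regular, separate adapter
model.load_adapter(file_name_adapter_3, hotswap=False, adapter_name="other_adapter")
```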

However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
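
As a sketch, if one adapter targets only the attention projections while another also targets the MLP layers, the broader adapter should be loaded first (the adapter paths below are placeholders):

```python
model.enable_peft_hotswap(target_rank=max_rank)
# "lora_attn_and_mlp" targets the attention and MLP layers,
# "lora_attn_only" targets only the attention layers
model.load_adapter("path/to/lora_attn_and_mlp", adapter_name="default")
model = torch.compile(model, ...)
# hotswapping the narrower adapter does not trigger recompilation;
# loading them in the reverse order would
model.load_adapter("path/to/lora_attn_only", adapter_name="default")
```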

> [!Tip]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue in the [PEFT repository](https://github.com/huggingface/peft/issues) with a reproducible example.
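
A minimal sketch of such a check, reusing the placeholder names from above, could look like this:

```python
# raises an error instead of silently recompiling, which makes recompilation easy to spot
with torch._dynamo.config.patch(error_on_recompile=True):
    model.load_adapter(file_name_2, adapter_name="default")
    output = model(...)
```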

For an example of how `torch.compile` combined with hotswapping can improve runtime, check out [this blog post](https://huggingface.co/blog/lora-fast). Although that example uses Diffusers, similar improvements can be expected here.
151 changes: 144 additions & 7 deletions src/transformers/integrations/peft.py
@@ -17,7 +17,7 @@
import json
import os
import re
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union

from packaging import version

@@ -89,6 +89,7 @@ class PeftAdapterMixin:
"""

_hf_peft_config_loaded = False
_prepare_peft_hotswap_kwargs: Optional[dict] = None

def load_adapter(
self,
@@ -104,6 +105,7 @@ def load_adapter(
adapter_state_dict: Optional[dict[str, "torch.Tensor"]] = None,
low_cpu_mem_usage: bool = False,
is_trainable: bool = False,
hotswap: bool | Literal["auto"] = "auto",
adapter_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
@@ -159,12 +161,63 @@ def load_adapter(
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
hotswap (`"auto"` or `bool`, *optional*, defaults to `"auto"`):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means
that, instead of loading an additional adapter, this will take the existing adapter weights and replace
them with the weights of the new adapter. This can be faster and more memory efficient. However, the
main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new
adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name`
should be the name of an already loaded adapter.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:

```py
model = AutoModel.from_pretrained(...)
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
model.enable_peft_hotswap(target_rank=max_rank)
model.load_adapter(file_name_1, adapter_name="default")
# optionally compile the model now
model = torch.compile(model, ...)
output_1 = model(...)
# now you can hotswap the 2nd adapter, use the same name as for the 1st
# hotswap is activated by default since enable_peft_hotswap was called
model.load_adapter(file_name_2, adapter_name="default")
output_2 = model(...)
```

By default, hotswap is disabled and requires passing `hotswap=True`. If you called
`enable_peft_hotswap` first, it is enabled. You can still manually disable it in that case by passing
`hotswap=False`.

Note that hotswapping comes with a couple of limitations documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
adapter_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

from peft import PeftType

if hotswap == "auto":
# if user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap
hotswap_enabled = getattr(self, "_hotswap_enabled", False)
not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
Member: We should not need bool casting here, do we?

Suggested change:
- not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
+ not_first_adapter = self._hf_peft_config_loaded and adapter_name in self.peft_config

Member Author: Right now, `_hf_peft_config_loaded` is a bool, so coercing is not needed. But if its type changes in the future, the type of `not_first_adapter` could also change, so the cast makes this line future-proof. Just an example:

>>> _hf_peft_config_loaded = {"foo": 1}
>>> adapter_name = "default"
>>> peft_config = {"default": 2}
>>> not_first_adapter = _hf_peft_config_loaded and adapter_name in peft_config
>>> not_first_adapter
True
>>> _hf_peft_config_loaded = {}  # falsy value short-circuits the conditional
>>> not_first_adapter = _hf_peft_config_loaded and adapter_name in peft_config
>>> not_first_adapter
{}

As you can see, in the last line, we suddenly have a different type for not_first_adapter.

hotswap = hotswap_enabled and not_first_adapter

if hotswap:
min_version_hotswap = "0.15.0"
if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config):
raise ValueError(
"To hotswap an adapter, there must already be an existing adapter with the same adapter name."
)
if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")

Comment on lines +218 to +220
Member: Just concerned about that, maybe `False` would be a better default, no? It looks like the current code would crash without reason when loading more adapters that are not LoRA.
Or maybe a `None` value that becomes `"auto"` if we use LoRA, and `False` otherwise?

Member Author: This would only crash if the user tries to load a non-LoRA adapter and if they:

  • passed hotswap=True or
  • called enable_peft_hotswap

In either case, the user intent is to use hotswapping. Therefore, I think that raising is the better choice; otherwise, the user would think they used hotswapping successfully when it's not actually being used.

# peft only supports low_cpu_mem_usage starting from v0.13.0
peft_load_kwargs = {}
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
@@ -187,8 +240,12 @@ def load_adapter(
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
from peft.utils import set_peft_model_state_dict

if self._hf_peft_config_loaded and adapter_name in self.peft_config:
if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config):
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)):
raise ValueError(
"To hotswap an adapter, there must already be an existing adapter with the same adapter name."
)

if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
raise ValueError(
@@ -237,8 +294,12 @@ def load_adapter(
)
peft_config.inference_mode = not is_trainable

# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
if peft_config.peft_type != PeftType.LORA:
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")

if not hotswap:
# Create and add fresh new adapters into the model, unless the weights are hotswapped
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
@@ -261,12 +322,47 @@ def load_adapter(
# Early exit of the loop
if n_replace > 0:
break

# For hotswapping, we need the adapter name to be present in the state dict keys
if hotswap:
if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
new_key = new_key[: -len(".weight")] + f".{adapter_name}.weight"
elif key.endswith("lora_B.bias"): # lora_bias=True option
new_key = new_key[: -len(".bias")] + f".{adapter_name}.bias"
processed_adapter_state_dict[new_key] = value

# Load state dict
incompatible_keys = set_peft_model_state_dict(
self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
)
if not hotswap:
incompatible_keys = set_peft_model_state_dict(
self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
)

if self._prepare_peft_hotswap_kwargs is not None:
# For hotswapping of compiled models or adapters with different ranks.
# If the user called enable_peft_hotswap, we need to ensure it is called:
# - after the first adapter was loaded
# - before the model is compiled and the 2nd adapter is being hotswapped in
# Therefore, it needs to be called here
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
# We only want to call prepare_model_for_compiled_hotswap once
self._prepare_peft_hotswap_kwargs = None
Comment on lines +340 to +350
Member: Should this be under the condition `if not hotswap`? 🤔

Member Author: It is, see line 335.

else:
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict

check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=processed_adapter_state_dict,
adapter_name=adapter_name,
config=peft_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
raise
incompatible_keys = None

if incompatible_keys is not None:
err_msg = ""
@@ -308,6 +404,47 @@ def load_adapter(
offload_index=offload_index,
)

def enable_peft_hotswap(
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
"""Enables the possibility to hotswap PEFT adapters with different ranks, or, if the model is compiled, without
triggering recompilation.

Right now, hotswapping is only supported for LoRA.

Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ. If the ranks are all identical and the model is not compiled, hotswapping works
without calling this method first.

Args:
target_rank (`int`, *optional*, defaults to `128`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
min_version_hotswap = "0.15.0"
if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")

if getattr(self, "peft_config", {}):
if check_compiled == "error":
raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
elif check_compiled == "warn":
logger.warning(
"It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
)
elif check_compiled != "ignore":
raise ValueError(
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
)

self._hotswap_enabled = True
self._prepare_peft_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}

def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
r"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT