diff --git a/docs/source/en/peft.md b/docs/source/en/peft.md
index 4ee0e2681963..6b963aeb4bd9 100644
--- a/docs/source/en/peft.md
+++ b/docs/source/en/peft.md
@@ -151,3 +151,49 @@ model.enable_adapters()
 # disable all adapters
 model.disable_adapters()
 ```
+
+## Hotswapping adapters
+
+A common use case when serving multiple adapters is to load one adapter first, generate output, load another adapter, generate more output, and so on. This can be inefficient: each time a new adapter is loaded, new memory is reserved, and if the model is compiled with `torch.compile`, it needs to be recompiled each time a new adapter is used. When switching adapters frequently, the compilation time may never be amortized.
+
+To better support this common workflow, you can "hotswap" a LoRA adapter to avoid accumulating memory and, in some cases, recompilation. Hotswapping requires an adapter to already be loaded; the new adapter weights are then swapped in-place into the existing adapter. Note that only LoRA is supported, other PEFT methods cannot be hotswapped yet.
+
+Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to pass the name of the existing adapter to be swapped (`"default"` is the default adapter name).
+
+```python
+model = AutoModel.from_pretrained(...)
+# load adapter 1 as normal
+model.load_adapter(file_name_adapter_1)
+# generate outputs with adapter 1
+...
+# now hotswap the 2nd adapter
+model.load_adapter(file_name_adapter_2, hotswap=True, adapter_name="default")
+# generate outputs with adapter 2
+```
+
+For compiled models, it is often necessary to call [`~integrations.peft.PeftAdapterMixin.enable_peft_hotswap`] to avoid recompilation. Call this method _before_ loading the first adapter; `torch.compile` should be called _after_ loading the first adapter.
+
+```python
+model = AutoModel.from_pretrained(...)
+max_rank = ... # the highest rank among all LoRAs that you want to load
+# call *before* compiling and loading the LoRA adapter
+model.enable_peft_hotswap(target_rank=max_rank)
+model.load_adapter(file_name_1, adapter_name="default")
+# optionally compile the model now
+model = torch.compile(model, ...)
+output_1 = model(...)
+# now you can hotswap the 2nd adapter, use the same name as for the 1st
+model.load_adapter(file_name_2, adapter_name="default")
+output_2 = model(...)
+```
+
+The `target_rank=max_rank` argument is important: it sets the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. If in doubt, use a higher value. By default, this value is 128.
+
+By default, hotswapping is disabled and requires you to pass `hotswap=True` to `load_adapter`. However, if you called `enable_peft_hotswap` first, hotswapping is enabled by default; pass `hotswap=False` if you want to avoid it.
+
+However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
+
+> [!TIP]
+> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [PEFT](https://github.com/huggingface/peft/issues) with a reproducible example.
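+
+For instance, here is a minimal sketch that reuses `model`, `file_name_2`, and `output_2` from the compiled example above and turns any unexpected recompilation into an error instead of a silent slowdown:
+
+```python
+import torch
+
+# any recompilation inside this context raises an error instead of silently recompiling
+with torch._dynamo.config.patch(error_on_recompile=True):
+    model.load_adapter(file_name_2, adapter_name="default")
+    output_2 = model(...)
+```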
+
+For an example of how the use of `torch.compile` in combination with hotswapping can improve runtime, check out [this blogpost](https://huggingface.co/blog/lora-fast). Although that example uses Diffusers, similar improvements can be expected here.
diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 3198cff77146..a35cc7d0a435 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -17,7 +17,7 @@
 import json
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from packaging import version
 
@@ -89,6 +89,7 @@ class PeftAdapterMixin:
     """
 
     _hf_peft_config_loaded = False
+    _prepare_peft_hotswap_kwargs: Optional[dict] = None
 
     def load_adapter(
         self,
@@ -104,6 +105,7 @@ def load_adapter(
         adapter_state_dict: Optional[dict[str, "torch.Tensor"]] = None,
         low_cpu_mem_usage: bool = False,
         is_trainable: bool = False,
+        hotswap: bool | Literal["auto"] = "auto",
         adapter_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
@@ -159,12 +161,63 @@ def load_adapter(
             is_trainable (`bool`, *optional*, defaults to `False`):
                 Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only
                 be used for inference.
+            hotswap (`"auto"` or `bool`, *optional*, defaults to `"auto"`):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means
+                that, instead of loading an additional adapter, this will take the existing adapter weights and replace
+                them with the weights of the new adapter. This can be faster and more memory efficient. However, the
+                main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new
+                adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name`
+                should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                model = AutoModel.from_pretrained(...)
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                model.enable_peft_hotswap(target_rank=max_rank)
+                model.load_adapter(file_name_1, adapter_name="default")
+                # optionally compile the model now
+                model = torch.compile(model, ...)
+                output_1 = model(...)
+                # now you can hotswap the 2nd adapter, use the same name as for the 1st
+                # hotswap is activated by default since enable_peft_hotswap was called
+                model.load_adapter(file_name_2, adapter_name="default")
+                output_2 = model(...)
+                ```
+
+                By default, hotswap is disabled and requires passing `hotswap=True`. If you called
+                `enable_peft_hotswap` first, it is enabled. You can still manually disable it in that case by passing
+                `hotswap=False`.
+
+                Note that hotswapping comes with a couple of limitations documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             adapter_kwargs (`dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
""" check_peft_version(min_version=MIN_PEFT_VERSION) + from peft import PeftType + + if hotswap == "auto": + # if user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap + hotswap_enabled = getattr(self, "_hotswap_enabled", False) + not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config)) + hotswap = hotswap_enabled and not_first_adapter + + if hotswap: + min_version_hotswap = "0.15.0" + if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap): + raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.") + if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config): + raise ValueError( + "To hotswap an adapter, there must already be an existing adapter with the same adapter name." + ) + if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()): + raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.") + # peft only supports low_cpu_mem_usage starting from v0.13.0 peft_load_kwargs = {} key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None @@ -187,8 +240,12 @@ def load_adapter( from peft import PeftConfig, inject_adapter_in_model, load_peft_weights from peft.utils import set_peft_model_state_dict - if self._hf_peft_config_loaded and adapter_name in self.peft_config: + if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config): raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)): + raise ValueError( + "To hotswap an adapter, there must already be an existing adapter with the same adapter name." + ) if peft_model_id is None and (adapter_state_dict is None and peft_config is None): raise ValueError( @@ -237,8 +294,12 @@ def load_adapter( ) peft_config.inference_mode = not is_trainable - # Create and add fresh new adapters into the model. - inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) + if peft_config.peft_type != PeftType.LORA: + raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.") + + if not hotswap: + # Create and add fresh new adapters into the model, unless the weights are hotswapped + inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) if not self._hf_peft_config_loaded: self._hf_peft_config_loaded = True @@ -261,12 +322,47 @@ def load_adapter( # Early exit of the loop if n_replace > 0: break + + # For hotswapping, we need the adapter name to be present in the state dict keys + if hotswap: + if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"): + new_key = new_key[: -len(".weight")] + f".{adapter_name}.weight" + elif key.endswith("lora_B.bias"): # lora_bias=True option + new_key = new_key[: -len(".bias")] + f".{adapter_name}.bias" processed_adapter_state_dict[new_key] = value # Load state dict - incompatible_keys = set_peft_model_state_dict( - self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs - ) + if not hotswap: + incompatible_keys = set_peft_model_state_dict( + self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs + ) + + if self._prepare_peft_hotswap_kwargs is not None: + # For hotswapping of compiled models or adapters with different ranks. 
+                # If the user called enable_peft_hotswap, we need to ensure it is called:
+                # - after the first adapter was loaded
+                # - before the model is compiled and the 2nd adapter is being hotswapped in
+                # Therefore, it needs to be called here
+                from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+                prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
+                # We only want to call prepare_model_for_compiled_hotswap once
+                self._prepare_peft_hotswap_kwargs = None
+        else:
+            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+
+            check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
+            try:
+                hotswap_adapter_from_state_dict(
+                    model=self,
+                    state_dict=processed_adapter_state_dict,
+                    adapter_name=adapter_name,
+                    config=peft_config,
+                )
+            except Exception as e:
+                logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
+                raise
+            incompatible_keys = None
 
         if incompatible_keys is not None:
             err_msg = ""
@@ -308,6 +404,47 @@
                 offload_index=offload_index,
             )
 
+    def enable_peft_hotswap(
+        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+    ) -> None:
+        """Enables hotswapping of PEFT adapters with different ranks or, if the model is compiled, without triggering
+        recompilation.
+
+        Right now, hotswapping is only supported for LoRA.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ. If the ranks are all identical and the model is not compiled, hotswapping works
+        without calling this method first.
+
+        Args:
+            target_rank (`int`, *optional*, defaults to `128`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
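+
+        Example (a sketch of the intended call order; `file_name_1` and `file_name_2` are placeholders for LoRA
+        checkpoints):
+
+        ```py
+        model = AutoModel.from_pretrained(...)
+        # call before loading the first adapter and before compiling
+        model.enable_peft_hotswap(target_rank=64)
+        model.load_adapter(file_name_1, adapter_name="default")
+        model = torch.compile(model)
+        # later adapters loaded under the same name are hotswapped in-place without recompilation
+        model.load_adapter(file_name_2, adapter_name="default")
+        ```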
+        """
+        min_version_hotswap = "0.15.0"
+        if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+            raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+
+        if getattr(self, "peft_config", {}):
+            if check_compiled == "error":
+                raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
+            elif check_compiled == "warn":
+                logger.warning(
+                    "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
+                )
+            elif check_compiled != "ignore":
+                raise ValueError(
+                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+                )
+
+        self._hotswap_enabled = True
+        self._prepare_peft_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
+
     def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
         r"""
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index ad0978164043..21ee465ae6cf 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import importlib
 import os
 import re
@@ -940,3 +941,191 @@ def test_peft_pipeline_no_warning(self):
 
         # Generate text to verify pipeline works
         _ = lora_generator(text, max_new_tokens=20)
+
+
+@require_peft
+@require_torch
+@slow
+class PeftHotswapIntegrationTest(unittest.TestCase):
+    def tearDown(self):
+        # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
+        # there will be recompilation errors, as torch caches the model when run in the same process.
+        torch.compiler.reset()
+        gc.collect()
+
+    def _check_model_hotswap(self, *, rank1, rank2, do_compile):
+        # utility method that checks that we can successfully hotswap adapters, with the model outputs corresponding to
+        # the respective adapters
+        from peft import LoraConfig
+
+        torch.manual_seed(0)
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        input = torch.randint(0, 100, (1, 10)).to(torch_device)
+        with torch.inference_mode():
+            base_output = model(input).logits
+
+        # create 2 adapters
+        model.add_adapter(LoraConfig(r=rank1, init_lora_weights=False), adapter_name="adapter_1")
+        with torch.inference_mode():
+            lora_1_output = model(input).logits
+
+        # second adapter may have a different rank
+        model.add_adapter(LoraConfig(r=rank2, init_lora_weights=False), adapter_name="adapter_2")
+        model.set_adapter("adapter_2")
+        with torch.inference_mode():
+            lora_2_output = model(input).logits
+
+        # sanity checks
+        self.assertFalse(torch.allclose(base_output, lora_1_output, atol=1e-6, rtol=1e-6))
+        self.assertFalse(torch.allclose(base_output, lora_2_output, atol=1e-6, rtol=1e-6))
+        self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=1e-6, rtol=1e-6))
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            path_1 = os.path.join(tmpdirname, "adapter_1")
+            path_2 = os.path.join(tmpdirname, "adapter_2")
+            model.set_adapter("adapter_1")
+            model.save_pretrained(path_1)
+            model.set_adapter("adapter_2")
+            model.save_pretrained(path_2)
+            del model
+
+            model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+            enable_hotswap = do_compile or (rank1 != rank2)
+            if enable_hotswap:
+                # calling this is only needed if we want to compile the model or if the ranks are different
+                model.enable_peft_hotswap(target_rank=max(rank1, rank2))
+
+            # load the first adapter without hotswap (hotswap requires an existing adapter)
+            model.load_adapter(path_1, adapter_name="adapter_1")
+            if do_compile:
+                # compile the model after loading the first adapter
+                model = torch.compile(model, mode="reduce-overhead")
+
+            with torch.inference_mode():
+                lora_1_output_loaded = model(input).logits
+            self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=1e-6, rtol=1e-6))
+
+            # hotswap in adapter_2, output should be same as lora_2_output
+            if enable_hotswap:
+                # after calling enable_peft_hotswap, hotswap will automatically be enabled
+                model.load_adapter(path_2, adapter_name="adapter_1")
+            else:
+                # enable_peft_hotswap was not called, need to explicitly pass hotswap=True
+                model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True)
+
+            with torch.inference_mode():
+                lora_2_output_loaded = model(input).logits
+            self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=1e-6, rtol=1e-6))
+
+    def test_hotswap_wrong_peft_type_raises(self):
+        # only LoRA is supported for now
+        from peft import IA3Config
+
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        peft_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        peft_config = IA3Config(feedforward_modules=[])
+        model.add_adapter(peft_config, adapter_name="ia3")
+
+        msg = "Hotswapping is currently only supported for LoRA"
+        with self.assertRaisesRegex(ValueError, msg):
+            model.load_adapter(peft_id, adapter_name="ia3", hotswap=True)
+
+    def test_hotswap_without_existing_adapter_raises(self):
+        # we can only hotswap if there is already an adapter with the same name
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        peft_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+
+        msg = "To hotswap an adapter, there must already be an existing adapter with the same adapter name"
+        with self.assertRaisesRegex(ValueError, msg):
+            model.load_adapter(peft_id, adapter_name="adapter_1", hotswap=True)
+
+    def test_hotswap_different_adapter_name_raises(self):
+        # we can only hotswap if there is already an adapter with the same name
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        peft_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        model.load_adapter(peft_id, adapter_name="adapter_1")
+
+        other_name = "does_not_exist_yet"
+        msg = "To hotswap an adapter, there must already be an existing adapter with the same adapter name"
+        with self.assertRaisesRegex(ValueError, msg):
+            model.load_adapter(peft_id, adapter_name=other_name, hotswap=True)
+
+    def test_enable_peft_hotswap_called_after_adapter_added_raises(self):
+        # ensure that when enable_peft_hotswap is called *after* loading the first adapter, an error is raised
+        from peft import LoraConfig
+
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        lora_config = LoraConfig()
+        model.add_adapter(lora_config)
+        msg = re.escape("Call `enable_peft_hotswap` before loading the first adapter.")
+
+        with self.assertRaisesRegex(RuntimeError, msg):
+            model.enable_peft_hotswap(target_rank=32)
+
+    def test_enable_peft_hotswap_called_after_adapter_added_warns(self):
+        # ensure that when enable_peft_hotswap is called *after* loading the first adapter, there is a warning if
+        # check_compiled="warn"
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        lora_config = LoraConfig()
+        model.add_adapter(lora_config)
+        msg = "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation"
+
+        with self.assertLogs(logger=logger, level="WARNING") as cm:
+            model.enable_peft_hotswap(target_rank=32, check_compiled="warn")
+        assert any(msg in log for log in cm.output)
+
+    def test_enable_peft_hotswap_called_after_adapter_added_ignored(self):
+        # Ensure that when enable_peft_hotswap is called *after* loading the first adapter, there is no error or
+        # warning if check_compiled="ignore". Note that assertNoLogs only works with Python 3.10+.
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        lora_config = LoraConfig()
+        model.add_adapter(lora_config)
+
+        with self.assertNoLogs(logger, level="WARNING"):
+            model.enable_peft_hotswap(target_rank=32, check_compiled="ignore")
+
+    def test_hotswap_without_compile_and_same_ranks_works(self):
+        self._check_model_hotswap(rank1=8, rank2=8, do_compile=False)
+
+    def test_hotswap_without_compile_and_with_lower_rank_works(self):
+        self._check_model_hotswap(rank1=13, rank2=7, do_compile=False)
+
+    def test_hotswap_without_compile_and_with_higher_rank_works(self):
+        self._check_model_hotswap(rank1=7, rank2=13, do_compile=False)
+
+    def test_hotswap_with_compile_and_same_ranks_works(self):
+        # It's important to add this context to raise an error on recompilation
+        with (
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch._inductor.utils.fresh_inductor_cache(),
+        ):
+            self._check_model_hotswap(rank1=8, rank2=8, do_compile=True)
+
+    def test_hotswap_with_compile_and_lower_rank_works(self):
+        # It's important to add this context to raise an error on recompilation
+        with (
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch._inductor.utils.fresh_inductor_cache(),
+        ):
+            self._check_model_hotswap(rank1=13, rank2=7, do_compile=True)
+
+    def test_hotswap_with_compile_and_higher_rank_works(self):
+        # It's important to add this context to raise an error on recompilation
+        with (
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch._inductor.utils.fresh_inductor_cache(),
+        ):
+            self._check_model_hotswap(rank1=7, rank2=13, do_compile=True)