 import os
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

 import safetensors
 import torch
@@ -144,8 +144,7 @@ def _optionally_disable_offloading(cls, _pipeline):
     def load_lora_adapter(
         self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
     ):
-        r"""
-        Loads a LoRA adapter into the underlying model.
+        r"""Loads a LoRA adapter into the underlying model.

         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -194,21 +193,21 @@ def load_lora_adapter(
                 However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
                 loading the new adapter does not require recompilation of the model.

-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
+                If the model is compiled, or if the new adapter and the old adapter have different ranks and/or LoRA
+                alphas (i.e. scaling), you need to call an additional method before loading the adapter:

                 ```py
-                from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
-                model = ...  # load diffusers model with first LoRA adapter
+                pipeline = ...  # load diffusers pipeline
                 max_rank = ...  # the highest rank among all LoRAs that you want to load
-                prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
-                model = torch.compile(model)
-                model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
                 ```

                 There are some limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
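For reference, here is a minimal end-to-end sketch of the hotswapping flow the updated docstring describes. The repository ids, the `transformer` attribute, and the adapter name are placeholders/assumptions, not values taken from this diff:

```py
# A sketch of the hotswap workflow, assuming a transformer-based pipeline;
# "base/model", "user/lora-1", and "user/lora-2" are placeholder repo ids.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("base/model", torch_dtype=torch.float16).to("cuda")

# Must be called before the first adapter is loaded; target_rank is the highest
# rank among all LoRAs that will ever be hotswapped in.
pipeline.enable_lora_hotswap(target_rank=64)

# Load the first adapter, then (optionally) compile.
pipeline.load_lora_weights("user/lora-1", adapter_name="default_0")
pipeline.transformer = torch.compile(pipeline.transformer)

# Swap the second adapter into the same slot, without recompilation, via the
# model-level loader documented above.
pipeline.transformer.load_lora_adapter(
    "user/lora-2", prefix="transformer", adapter_name="default_0", hotswap=True
)
```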
@@ -837,16 +836,35 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)

-    def enable_lora_hotswap(self, target_rank: int) -> None:
+    def enable_lora_hotswap(
+        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+    ) -> None:
         """Enables the possibility to hotswap LoRA adapters.

         Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
         the loaded adapters differ.

         Args:
-            target_rank (`int`):
+            target_rank (`int`, *optional*, defaults to `128`):
                 The highest rank among all the adapters that will be loaded.
+
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                  - "error" (default): raise an error
+                  - "warn": issue a warning
+                  - "ignore": do nothing
         """
         if getattr(self, "peft_config", {}):
-            raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
-        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank}
+            if check_compiled == "error":
+                raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
+            elif check_compiled == "warn":
+                logger.warning(
+                    "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+                )
+            elif check_compiled != "ignore":
+                raise ValueError(
+                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+                )
+
+        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
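A brief usage sketch of the new `check_compiled` argument, based on the branch added above; `model` stands in for any diffusers model that exposes this mixin method (e.g. a pipeline's transformer or UNet):

```py
model = ...  # any diffusers model exposing enable_lora_hotswap

# Normal use: call before any adapter is loaded; the settings are simply
# recorded in `_prepare_lora_hotswap_kwargs` for later use when loading LoRAs.
model.enable_lora_hotswap(target_rank=128)

# If `peft_config` is already populated (an adapter was loaded first), the
# behaviour depends on check_compiled:
#   "error"  -> RuntimeError (the default)
#   "warn"   -> a warning is logged and the settings are still recorded
#   "ignore" -> the settings are recorded silently
model.enable_lora_hotswap(target_rank=128, check_compiled="warn")
```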