diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 7831f099ae3e..574b8499e1d5 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
+## LoraBaseMixin
+
+[[autodoc]] loaders.lora_base.LoraBaseMixin
+
 ## StableDiffusionLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 [[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
 
-## LoraBaseMixin
-
-[[autodoc]] loaders.lora_base.LoraBaseMixin
-
 ## WanLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
\ No newline at end of file
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 8b417341ca13..87b5b1a3fa23 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 
 
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
 
@@ -453,6 +464,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
 
@@ -464,33 +493,6 @@ def save_lora_weights(cls, **kwargs):
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
@@ -661,19 +663,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                         self._merged_adapters = self._merged_adapters - {adapter}
                         module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        """
+        Set the currently active adapters for use in the pipeline.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            adapter_weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the pipeline. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
@@ -743,6 +763,24 @@ def set_adapters(
             set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
+        """
+        Disables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -755,6 +793,24 @@ def disable_lora(self):
             disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
+        """
+        Enables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -768,10 +824,26 @@ def enable_lora(self):
 
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +944,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                             adapter_name
                         ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Hotswap adapters without triggering recompilation of a compiled model, even when the ranks of the loaded
+        adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle a model that is already compiled. The options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
@@ -887,6 +977,7 @@ def write_lora_layers(
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -927,28 +1018,18 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                  - "error" (default): raise an error
-                  - "warn": issue a warning
-                  - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3cc3296ce43d..343623071340 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -85,17 +85,6 @@ class PeftAdapterMixin:
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     def load_lora_adapter(
@@ -444,7 +433,7 @@ def set_adapters(
         weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
     ):
         """
-        Set the currently active adapters for use in the UNet.
+        Set the currently active adapters for use in the diffusion network (e.g. UNet, transformer).
 
         Args:
             adapter_names (`List[str]` or `str`):
@@ -466,7 +455,7 @@
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -714,7 +703,7 @@ def disable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.disable_lora()
+        pipeline.unet.disable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -737,7 +726,7 @@ def enable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.enable_lora()
+        pipeline.unet.enable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -764,7 +753,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
         )
-        pipeline.delete_adapters("cinematic")
+        pipeline.unet.delete_adapters("cinematic")
         ```
         """
         if not USE_PEFT_BACKEND:
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 621194df9d07..68be84119177 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -394,17 +394,6 @@ def _process_lora(
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     def save_attn_procs(
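Taken together, the members relocated to the top of `LoraBaseMixin` (`lora_scale`, `num_fused_loras`, `fused_loras`) and the new `set_adapters` docstring describe one pipeline-level workflow. Below is a minimal sketch of that workflow, reusing the illustrative SDXL checkpoint and LoRA repositories from the docstrings; the fused-adapter values shown in comments assume `fuse_lora` records adapter names in `_merged_adapters`, as the `unfuse_lora` hunk above suggests.

```py
# Minimal sketch of the pipeline-level LoRA API covered by this patch.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Activate both adapters with per-adapter weights (pipeline-level `set_adapters`).
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])

# `lora_scale` falls back to 1.0 until the pipeline sets `_lora_scale` at runtime.
print(pipeline.lora_scale)  # 1.0

# Nothing is fused yet, so the `_merged_adapters`-backed properties are empty.
print(pipeline.num_fused_loras, pipeline.fused_loras)  # 0, set()

# After fusing, the bookkeeping reflects the merged adapter names.
pipeline.fuse_lora(adapter_names=["cinematic"])
print(pipeline.num_fused_loras)  # expected: 1
```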
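`enable_lora_hotswap` gets only an argument list above, so a hedged sketch of the intended call order may help. The `target_rank` value and repository ids below are illustrative, and the `hotswap=True` flag on `load_lora_weights` is an assumption about the loader API rather than something this diff establishes.

```py
# Sketch: hotswapping LoRAs on a compiled pipeline (assumed workflow).
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Call before loading the first adapter and before compiling; `target_rank`
# must cover the highest LoRA rank that will ever be swapped in.
pipeline.enable_lora_hotswap(target_rank=64)

pipeline.load_lora_weights("your-org/lora-a", adapter_name="default_0")  # hypothetical repo id
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

# Loading into the same adapter name replaces the weights in place, so the
# compiled graph is reused even if the new adapter has a different rank.
pipeline.load_lora_weights("your-org/lora-b", adapter_name="default_0", hotswap=True)  # assumed flag
```

The ordering matters: enabling hotswap first presumably lets the LoRA buffers be allocated once at `target_rank`, which is why the docstring asks for the highest rank among all adapters that will be loaded.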