Merged
8 changes: 4 additions & 4 deletions docs/source/en/api/loaders/lora.md
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

</Tip>

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin

## StableDiffusionLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin

## WanLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
201 changes: 141 additions & 60 deletions src/diffusers/loaders/lora_base.py
@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(


def _func_optionally_disable_offloading(_pipeline):
"""
Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.

Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.

Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
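For orientation, a minimal sketch of how the returned flags are meant to be consumed — `pipe` and the re-enabling calls are assumptions based on the docstring and the public offloading API, not part of this diff:

```py
# Hypothetical caller (sketch): disable offloading, mutate LoRA state, restore it.
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline=pipe)
# ... load LoRA weights into the now-unhooked modules ...
if is_model_cpu_offload:
    pipe.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    pipe.enable_sequential_cpu_offload()
```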

@@ -453,6 +464,24 @@ class LoraBaseMixin:
_lora_loadable_modules = []
_merged_adapters = set()

@property
def lora_scale(self) -> float:
"""
Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
returns 1.0.
"""
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
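As context for `lora_scale`: the value is typically set at call time through the pipeline's attention kwargs. A hedged sketch for a UNet-based pipeline (the `cross_attention_kwargs={"scale": ...}` route is how `_lora_scale` usually gets populated):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
# Passing `scale` sets `_lora_scale`, which the `lora_scale` property then reports.
image = pipeline("pixel art of a corgi", cross_attention_kwargs={"scale": 0.7}).images[0]
```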

@property
def num_fused_loras(self):
"""Returns the number of LoRAs that have been fused."""
return len(self._merged_adapters)

@property
def fused_loras(self):
"""Returns names of the LoRAs that have been fused."""
return self._merged_adapters
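A short sketch of what these properties report, continuing a pipeline like the one above (printed values are illustrative):

```py
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(adapter_names=["pixel"])
print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # {"pixel"}
```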

def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")

@@ -464,33 +493,6 @@ def save_lora_weights(cls, **kwargs):
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")

@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.

Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)

@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
return _fetch_state_dict(*args, **kwargs)

@classmethod
def _best_guess_weight_name(cls, *args, **kwargs):
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
return _best_guess_weight_name(*args, **kwargs)

def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
@@ -661,19 +663,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
self._merged_adapters = self._merged_adapters - {adapter}
module.unmerge()

@property
def num_fused_loras(self):
return len(self._merged_adapters)

@property
def fused_loras(self):
return self._merged_adapters

def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
"""
Set the currently active adapters for use in the pipeline.

Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[float, Dict, List[float], List[Dict]]`, *optional*):
The adapter weight(s) to use across the pipeline's LoRA-loadable components: a single float, a dict keyed
by component, or a list of either with one entry per adapter. If `None`, the weights are set to `1.0` for
all adapters.

Example:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if isinstance(adapter_weights, dict):
components_passed = set(adapter_weights.keys())
lora_components = set(self._lora_loadable_modules)
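# Hedged illustration (not part of this diff): a plain dict is keyed by loadable
# component, and keys outside `_lora_loadable_modules` are rejected below, e.g.
#   pipeline.set_adapters("pixel", adapter_weights={"unet": 0.8, "text_encoder": 0.5})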
@@ -743,6 +763,24 @@ def set_adapters(
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])

def disable_lora(self):
"""
Disables the active LoRA layers of the pipeline.

Example:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.disable_lora()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

@@ -755,6 +793,24 @@ def disable_lora(self):
disable_lora_for_text_encoder(model)

def enable_lora(self):
"""
Enables the active LoRA layers of the pipeline.

Example:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.enable_lora()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

@@ -768,10 +824,26 @@ def enable_lora(self):

def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the pipeline.

Args:
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
[Review comment — Member]
Suggested change:
    adapter_names (`Union[List[str], str]`):
    adapter_names (`List[str]` or `str`):

[Reply — Member (Author)]
I think this kind of correction (expanding `Union` to an "or" in the docstrings) should be made in a bulk PR, so skipping it here.
The names of the adapter to delete. Can be a single string or a list of strings
The names of the adapters to delete.

Example:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
)
pipeline.delete_adapters("cinematic")
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +944,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
adapter_name
].to(device)

def enable_lora_hotswap(self, **kwargs) -> None:
"""
Enables hotswapping of LoRA adapters without triggering recompilation of a compiled model, even when the
ranks of the loaded adapters differ.

Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle a model that is already compiled. The options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
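A hedged usage sketch — `pipeline`, the repo ids, and the prompt are placeholders, and the `hotswap=True` flag on `load_lora_weights` follows the diffusers hotswapping docs:

```py
import torch

pipeline.enable_lora_hotswap(target_rank=64)  # call before loading adapters / compiling
pipeline.load_lora_weights("org/lora-A", adapter_name="default_0")  # placeholder repo id
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
image = pipeline("a prompt").images[0]
# Swap in different weights under the same adapter name, without recompilation:
pipeline.load_lora_weights("org/lora-B", adapter_name="default_0", hotswap=True)
image = pipeline("a prompt").images[0]
```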

@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
@@ -887,6 +977,7 @@ def write_lora_layers(
safe_serialization: bool,
lora_adapter_metadata: Optional[dict] = None,
):
"""Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -927,28 +1018,18 @@ def save_function(weights, filename):
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")

Removed (these moved earlier in the class):

    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
        # if _lora_scale has not been set, return 1
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

    def enable_lora_hotswap(self, **kwargs) -> None:
        """Enables the possibility to hotswap LoRA adapters.

        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
        the loaded adapters differ.

        Args:
            target_rank (`int`):
                The highest rank among all the adapters that will be loaded.
            check_compiled (`str`, *optional*, defaults to `"error"`):
                How to handle the case when the model is already compiled, which should generally be avoided. The
                options are:
                - "error" (default): raise an error
                - "warn": issue a warning
                - "ignore": do nothing
        """
        for key, component in self.components.items():
            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
                component.enable_lora_hotswap(**kwargs)

Added (the private classmethods, moved to the bottom of the class):

    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        return _func_optionally_disable_offloading(_pipeline=_pipeline)

    @classmethod
    def _fetch_state_dict(cls, *args, **kwargs):
        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
        return _fetch_state_dict(*args, **kwargs)

    @classmethod
    def _best_guess_weight_name(cls, *args, **kwargs):
        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
        return _best_guess_weight_name(*args, **kwargs)
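Per the deprecation messages, downstream code should migrate to the module-level helpers:

```py
# Replacement for the deprecated classmethods (as stated in the deprecation messages):
from diffusers.loaders.lora_base import _best_guess_weight_name, _fetch_state_dict
```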
21 changes: 5 additions & 16 deletions src/diffusers/loaders/peft.py
@@ -85,17 +85,6 @@ class PeftAdapterMixin:
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.

Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)

def load_lora_adapter(
@@ -444,7 +433,7 @@ def set_adapters(
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).

Args:
adapter_names (`List[str]` or `str`):
@@ -466,7 +455,7 @@
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
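# Hedged aside (not part of this diff): for a transformer-based model the
# analogous model-level call would be `pipeline.transformer.set_adapters(...)`.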
```
"""
if not USE_PEFT_BACKEND:
@@ -714,7 +703,7 @@ def disable_lora(self):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.disable_lora()
pipeline.unet.disable_lora()
```
"""
if not USE_PEFT_BACKEND:
@@ -737,7 +726,7 @@ def enable_lora(self):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.enable_lora()
pipeline.unet.enable_lora()
```
"""
if not USE_PEFT_BACKEND:
@@ -764,7 +753,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
)
pipeline.delete_adapters("cinematic")
pipeline.unet.delete_adapters("cinematic")
```
"""
if not USE_PEFT_BACKEND:
11 changes: 0 additions & 11 deletions src/diffusers/loaders/unet.py
@@ -394,17 +394,6 @@ def _process_lora(
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.

Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)

def save_attn_procs(