Skip to content
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
d3fbd7b
[WIP][LoRA] Implement hot-swapping of LoRA
BenjaminBossan Sep 17, 2024
84bae62
Reviewer feedback
BenjaminBossan Sep 18, 2024
63ece9d
Reviewer feedback, adjust test
BenjaminBossan Oct 16, 2024
94c669c
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Oct 16, 2024
c7378ed
Fix, doc
BenjaminBossan Oct 16, 2024
7c67b38
Make fix
BenjaminBossan Oct 16, 2024
ea12e0d
Fix for possible g++ error
BenjaminBossan Oct 16, 2024
ec4b0d5
Add test for recompilation w/o hotswapping
BenjaminBossan Oct 18, 2024
e07323a
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 18, 2024
529a523
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 22, 2024
ac1346d
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 25, 2024
58b35ba
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
d21a988
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
488f2f0
Make hotswap work
BenjaminBossan Feb 7, 2025
ece3d0f
Merge branch 'main' into lora-hot-swapping
sayakpaul Feb 8, 2025
5ab1460
Address reviewer feedback:
BenjaminBossan Feb 10, 2025
bc157e6
Change order of test decorators
BenjaminBossan Feb 10, 2025
bd1da66
Split model and pipeline tests
BenjaminBossan Feb 11, 2025
119a8ed
Reviewer feedback: Move decorator to test classes
BenjaminBossan Feb 12, 2025
53c2f84
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 12, 2025
a715559
Apply suggestions from code review
BenjaminBossan Feb 13, 2025
e40390d
Reviewer feedback: version check, TODO comment
BenjaminBossan Feb 13, 2025
1b834ec
Add enable_lora_hotswap method
BenjaminBossan Feb 14, 2025
4b01401
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 14, 2025
2cd3665
Reviewer feedback: check _lora_loadable_modules
BenjaminBossan Feb 17, 2025
efbd820
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 18, 2025
e735ac2
Revert changes in unet.py
BenjaminBossan Feb 18, 2025
69b637d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 21, 2025
3a6677c
Add possibility to ignore enabled at wrong time
BenjaminBossan Feb 21, 2025
a96f3fd
Fix docstrings
BenjaminBossan Feb 21, 2025
deab0eb
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 27, 2025
2c6b435
Log possible PEFT error, test
BenjaminBossan Feb 27, 2025
ccb45f7
Raise helpful error if hotswap not supported
BenjaminBossan Feb 27, 2025
09e2ec7
Formatting
BenjaminBossan Feb 27, 2025
67ab6bf
More linter
BenjaminBossan Feb 27, 2025
f03fe6b
More ruff
BenjaminBossan Feb 27, 2025
2d407ca
Doc-builder complaint
BenjaminBossan Feb 27, 2025
6b59ecf
Update docstring:
BenjaminBossan Mar 3, 2025
f14146f
Merge branch 'main' into lora-hot-swapping
yiyixuxu Mar 3, 2025
a79876d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 5, 2025
c3c1bdf
Fix error in docstring
BenjaminBossan Mar 5, 2025
387ddf6
Update more methods with hotswap argument
BenjaminBossan Mar 7, 2025
7f72d0b
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 7, 2025
dec4d10
Add hotswap argument to load_lora_into_transformer
BenjaminBossan Mar 11, 2025
204f521
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 11, 2025
716f446
Extend docstrings
BenjaminBossan Mar 12, 2025
4d82111
Add version guards to tests
BenjaminBossan Mar 12, 2025
425cb39
Formatting
BenjaminBossan Mar 12, 2025
115c77d
Fix LoRA loading call to add prefix=None
BenjaminBossan Mar 12, 2025
5d90753
Run make fix-copies
BenjaminBossan Mar 12, 2025
62c1c13
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 12, 2025
d6d23b8
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 17, 2025
366632d
Add hot swap documentation to the docs
BenjaminBossan Mar 17, 2025
b181a47
Apply suggestions from code review
BenjaminBossan Mar 18, 2025
f2a6146
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,62 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support

</Tip>

### Hot swapping LoRA adapters

A common use case when serving multiple adapters is to load one adapter first, generate images, then load another adapter, generate more images, load another adapter, etc. This workflow would normally require calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save on memory. Those are quite a few steps. Morever, if the model is compiled using `torch.compile`, performing these steps will result in recompilation, which takes time.

To better support this common workflow, diffusers offers the option to "hot swap" a LoRA adapter. This requires an adapter to already be loaded. Then, a new adapter can be hot swapped for the existing adapter, i.e. the weights are swapped in-place. This is more convenient, doesn't accumulate memory, and does not require recompilation, at least in some circumstances.

In general, hot swapping can be accomplished by passing `hotswap=True` when loading the LoRA adapter:

```python
pipe = ...
# load adapter 1 as normal
pipeline.load_lora_weights(file_name_adapter_1)
# generate some images with adapter 1
...
# now hot swap the 2nd adapter
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```

Notice that we passed `adapter_name="default_0"`. This is the default adapter name given by diffusers and it is important that we indicate the name of the existing adapter. If you loaded the first adapter under a different name, pass that name instead.

<Tip warning={true}>

Hot swapping is currently not supported for the text encoder. If the LoRA adapter targets the text encoder, don't use this feature.

</Tip>

Now when it comes to compiled models, the same code as above may also work without triggering recompilation, but only if the second adapter targets the exact same ranks, has the exact same LoRA ranks and also scales. For most adapters, this is not the case. Therefore, it is necessary to go through one more step, as shown in this snippet:

```python
pipe = ...
# call this extra method
pipe.enable_lora_hotswap(target_rank=max_rank)
# now load adapter 1
pipe.load_lora_weights(file_name_adapter_1)
# now compile the unet of the pipeline
pipe.unet = torch.compile(pipeline.unet, ...)
# generate some images with adapter 1
...
# now hot swap adapter 2
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```

By calling the [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method, diffusers makes it possible to hot swap the LoRA adapter without triggering recompilation. For this to work, call the method _before_ loading the first adapter. Also note that, as always, `torch.compile` has to be called _after_ loading the first adapter.

The `target_rank=max_rank` argument is important to let diffusers know what will be the maximum rank among all LoRA adapters that will be loaded. So if you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. By default, this value is 128. If in doubt, prefer a higher value.

Even after following these steps, there can be situations that will result in recompilation. Most notably, if the swapped in adapters targets more layers than the initial adapter, recompilation is needed. Try to load the adapter that targets most layers first. Read more about the limitations of hot swapping in the [PEFT documentation on hot swapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter).

<Tip>

To detect if the model was recompiled, move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager. If you detect recompilation despite following all the steps above, please open an issue on the [diffusers GitHub repository](https://github.com/huggingface/diffusers/issues) with a reproducer.

</Tip>

### Kohya and TheLastBen

Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
Expand Down
25 changes: 25 additions & 0 deletions src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
Expand All @@ -341,6 +342,10 @@ def _load_lora_into_text_encoder(
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix

# Safe prefix to check with.
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")

# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
Expand Down Expand Up @@ -908,3 +913,23 @@ def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

def enable_lora_hotswap(self, **kwargs) -> None:
"""Enables the possibility to hotswap LoRA adapters.

Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.

Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
Loading
Loading