Added unload_lora_weights to StableDiffusionPipeline #11172
````diff
@@ -502,26 +502,35 @@ def _best_guess_weight_name(cls, *args, **kwargs):

     def unload_lora_weights(self):
         """
-        Unloads the LoRA parameters.
+        Unloads all LoRA adapters from memory and resets internal adapter configuration.
+        This allows reloading the same adapter names without conflict.

-        Examples:
-
-        ```python
-        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
-        >>> pipeline.unload_lora_weights()
-        >>> ...
-        ```
+        Example:
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights("adapter_name")  # safe to reload
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
-            if model is not None:
-                if issubclass(model.__class__, ModelMixin):
-                    model.unload_lora()
-                elif issubclass(model.__class__, PreTrainedModel):
-                    _remove_text_encoder_monkey_patch(model)
+            if model is None:
+                continue
+
+            # For diffusers-style models
+            if issubclass(model.__class__, ModelMixin):
+                model.unload_lora()
+                if hasattr(model, "peft_config"):
+                    model.peft_config.clear()
````
Review comment on lines +523 to +524 (quoting `diffusers/src/diffusers/loaders/peft.py`, lines 688 to 689 at 54dac3a).

Review comment: +1 here.
```diff
+
+                if hasattr(model, "active_adapter"):
+                    model.active_adapter = None
```
```diff
+
+            # For transformers/PEFT models
+            elif issubclass(model.__class__, PreTrainedModel):
+                _remove_text_encoder_monkey_patch(model)
+
+        torch.cuda.empty_cache()
```
Review comment on `torch.cuda.empty_cache()`: This will not work on non-CUDA devices; we have a device-agnostic alternative (see `diffusers/src/diffusers/utils/testing_utils.py`, lines 1198 to 1199 at 54dac3a).
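For illustration only, a minimal sketch of what a backend-agnostic cache clear could look like. This is not the diffusers utility referenced above, and the `xpu`/`mps` branches are assumptions about which PyTorch backends are available:

```python
import torch


def empty_device_cache():
    # Minimal sketch of a backend-agnostic cache clear (illustration only,
    # not the diffusers helper referenced in the review comment above).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()  # assumes a PyTorch build with XPU support
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()  # assumes PyTorch >= 2.0 on Apple silicon
    # On plain CPU there is no caching allocator to flush.
```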
PR author: In testing, the basic unload–reload cycle works fine on the main branch when you use a single LoRA repeatedly. However, the issue we observed in production arises when you need to switch between different LoRA adapters within the same session, or rapidly update an adapter's weights, without restarting the pipeline.

Specifically, while the current implementation of `unload_lora_weights()` appears to delete the `peft_config` (and thus the adapter state) in many cases, in a dynamic server environment the underlying transformer modules may still retain residual adapter registration information (e.g., within the internal state of `BaseTunerLayer`). This can lead to subtle issues:

- If you attempt to load a new adapter (or reload the same adapter with updated weights) without a full pipeline reset, you might encounter errors such as "Adapter name ... already in use in the transformer", even though the adapter isn't visible via `get_list_adapters()`.
- In scenarios with rapid switching or dynamic weight updates, any residual state in the transformer (which isn't fully cleared by the current unload code) can cause unpredictable behavior or slight performance degradation.

This patch was intended to fully clear all traces of adapter registrations from the model's transformer modules so that subsequent adapter loads truly start from a clean slate. This is particularly important when you want to ensure that no stale parameters interfere with new weights in a production environment where you may be swapping adapters on the fly.

If the changes result in "the exact same functionality" in simple tests, that may be because the basic unload–reload cycle doesn't fully expose the issue in our unit tests. The additional code is aimed at more complex use cases, such as multiple adapters being loaded and unloaded sequentially in a long-running inference service, where even a small residual state can eventually lead to conflicts or degraded generation quality.

In summary, while the minimal test passes on the main branch, the use cases we're targeting are:

- Dynamic production pipelines where LoRA adapters need to be swapped or updated without restarting the server.
- Preventing state leakage in long-running sessions, ensuring that the transformer's internal adapter registrations are fully reset between loads.

For more context, here's the GitHub repo where I use dynamic LoRA loading/unloading in a production-style API setup: https://github.com/dhawan98/LoRA-Server.

Review comment: +1 on using a backend-agnostic cache-clearing method. But before doing so, could you demonstrate the memory saved with and without this change on a GPU?
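One way to produce the requested numbers would be to record CUDA memory statistics around the unload call, once with and once without the cache clearing. A rough sketch, assuming a CUDA device and a pipeline that already has a LoRA adapter loaded (the function and variable names are illustrative, not part of the PR):

```python
import torch


def report_unload_memory(pipe):
    # `pipe` is assumed to be a StableDiffusionPipeline on a CUDA device with
    # a LoRA adapter already loaded. Run once with and once without the
    # empty_cache() call in unload_lora_weights() and compare the output.
    torch.cuda.synchronize()
    allocated_before = torch.cuda.memory_allocated()
    reserved_before = torch.cuda.memory_reserved()

    pipe.unload_lora_weights()

    torch.cuda.synchronize()
    allocated_after = torch.cuda.memory_allocated()
    reserved_after = torch.cuda.memory_reserved()

    # empty_cache() mainly releases memory held by the caching allocator, so
    # the "reserved" numbers are the ones expected to differ between runs.
    print(f"allocated: {allocated_before / 2**20:.1f} MiB -> {allocated_after / 2**20:.1f} MiB")
    print(f"reserved:  {reserved_before / 2**20:.1f} MiB -> {reserved_after / 2**20:.1f} MiB")
```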
```diff

     def fuse_lora(
         self,
```
New test file added by the PR:

```diff
@@ -0,0 +1,28 @@
+from diffusers import StableDiffusionPipeline
```
Review comment: This needs to be made more generic and needs to go to …

Author reply: Thanks for the feedback Sayak. You're right: the current implementation doesn't fully handle all edge cases, especially in dynamic adapter switching scenarios. We'll prepare a new pull request that addresses the comments and ensures proper cleanup of residual adapter state. Appreciate the detailed reviews; more soon.
```diff
+
+import torch
+import pytest
+
+
+def test_unload_reload_same_adapter():
+    pipe = StableDiffusionPipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        torch_dtype=torch.float32
+    ).to("cpu")
+
+    lora_repo = "latent-consistency/lcm-lora-sdv1-5"
+
+    # Load and activate LoRA
+    pipe.load_lora_weights(lora_repo)
+    adapters = pipe.get_list_adapters()
+    adapter_name = list(adapters["unet"])[0]
+
+    pipe.set_adapters([adapter_name], [1.0])
+
+    # Unload
+    pipe.unload_lora_weights()
+
+    # Reload
+    pipe.load_lora_weights(lora_repo)
+    adapters = pipe.get_list_adapters()
+    adapter_name = list(adapters["unet"])[0]
+
+    pipe.set_adapters([adapter_name], [0.8])
```
Review comment: This seems unnecessary after the docstring change.