Skip to content

Commit a2071a1

Browse files
authored
[LoRA] introduce LoraBaseMixin to promote reusability. (#8670)
* introduce to promote reusability. * up * add more tests * up * remove comments. * fix fuse_nan test * clarify the scope of fuse_lora and unfuse_lora * remove space
1 parent d9f71ab commit a2071a1

File tree

13 files changed

+2255
-1708
lines changed

13 files changed

+2255
-1708
lines changed

docs/source/en/api/loaders/lora.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ specific language governing permissions and limitations under the License.
1212

1313
# LoRA
1414

15-
LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
15+
LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:
1616

1717
- [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
1818
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
19+
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
20+
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
21+
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
1922

2023
<Tip>
2124

@@ -29,4 +32,16 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
2932

3033
## StableDiffusionXLLoraLoaderMixin
3134

32-
[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
35+
[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
36+
37+
## SD3LoraLoaderMixin
38+
39+
[[autodoc]] loaders.lora.SD3LoraLoaderMixin
40+
41+
## AmusedLoraLoaderMixin
42+
43+
[[autodoc]] loaders.lora.AmusedLoraLoaderMixin
44+
45+
## LoraBaseMixin
46+
47+
[[autodoc]] loaders.lora_base.LoraBaseMixin

examples/amused/train_amused.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
import diffusers.optimization
4343
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
44-
from diffusers.loaders import LoraLoaderMixin
44+
from diffusers.loaders import AmusedLoraLoaderMixin
4545
from diffusers.utils import is_wandb_available
4646

4747

@@ -532,7 +532,7 @@ def save_model_hook(models, weights, output_dir):
532532
weights.pop()
533533

534534
if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
535-
LoraLoaderMixin.save_lora_weights(
535+
AmusedLoraLoaderMixin.save_lora_weights(
536536
output_dir,
537537
transformer_lora_layers=transformer_lora_layers_to_save,
538538
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
@@ -566,11 +566,11 @@ def load_model_hook(models, input_dir):
566566
raise ValueError(f"unexpected save model: {model.__class__}")
567567

568568
if transformer is not None or text_encoder_ is not None:
569-
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
570-
LoraLoaderMixin.load_lora_into_text_encoder(
569+
lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
570+
AmusedLoraLoaderMixin.load_lora_into_text_encoder(
571571
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
572572
)
573-
LoraLoaderMixin.load_lora_into_transformer(
573+
AmusedLoraLoaderMixin.load_lora_into_transformer(
574574
lora_state_dict, network_alphas=network_alphas, transformer=transformer
575575
)
576576

src/diffusers/loaders/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,18 @@ def text_encoder_attn_modules(text_encoder):
5555

5656
if is_torch_available():
5757
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
58+
_import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"]
59+
5860
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
5961
_import_structure["utils"] = ["AttnProcsLayers"]
6062
if is_transformers_available():
6163
_import_structure["single_file"] = ["FromSingleFileMixin"]
62-
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
64+
_import_structure["lora"] = [
65+
"AmusedLoraLoaderMixin",
66+
"LoraLoaderMixin",
67+
"SD3LoraLoaderMixin",
68+
"StableDiffusionXLLoraLoaderMixin",
69+
]
6370
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
6471
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
6572

@@ -69,12 +76,18 @@ def text_encoder_attn_modules(text_encoder):
6976
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
7077
if is_torch_available():
7178
from .single_file_model import FromOriginalModelMixin
79+
from .transformer_sd3 import SD3TransformerLoadersMixin
7280
from .unet import UNet2DConditionLoadersMixin
7381
from .utils import AttnProcsLayers
7482

7583
if is_transformers_available():
7684
from .ip_adapter import IPAdapterMixin
77-
from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
85+
from .lora import (
86+
AmusedLoraLoaderMixin,
87+
LoraLoaderMixin,
88+
SD3LoraLoaderMixin,
89+
StableDiffusionXLLoraLoaderMixin,
90+
)
7891
from .single_file import FromSingleFileMixin
7992
from .textual_inversion import TextualInversionLoaderMixin
8093

0 commit comments

Comments
 (0)