Skip to content

Commit cb78553

Browse files
authored
Update lora_pipeline.py
1 parent b38450d commit cb78553

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ def fuse_lora(
455455
```
456456
"""
457457
super().fuse_lora(
458-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
458+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
459459
)
460460

461461
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -476,7 +476,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
476476
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
477477
LoRA parameters then it won't have any effect.
478478
"""
479-
super().unfuse_lora(components=components)
479+
super().unfuse_lora(components=components, **kwargs)
480480

481481

482482
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -904,7 +904,7 @@ def fuse_lora(
904904
```
905905
"""
906906
super().fuse_lora(
907-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
907+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
908908
)
909909

910910
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -925,7 +925,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc
925925
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
926926
LoRA parameters then it won't have any effect.
927927
"""
928-
super().unfuse_lora(components=components)
928+
super().unfuse_lora(components=components, **kwargs)
929929

930930

931931
class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1312,7 +1312,7 @@ def fuse_lora(
13121312
```
13131313
"""
13141314
super().fuse_lora(
1315-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1315+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
13161316
)
13171317

13181318
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
@@ -1333,7 +1333,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
13331333
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
13341334
LoRA parameters then it won't have any effect.
13351335
"""
1336-
super().unfuse_lora(components=components)
1336+
super().unfuse_lora(components=components, **kwargs)
13371337

13381338

13391339
class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1847,7 +1847,7 @@ def fuse_lora(
18471847
)
18481848

18491849
super().fuse_lora(
1850-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1850+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
18511851
)
18521852

18531853
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1868,7 +1868,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
18681868
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
18691869
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
18701870

1871-
super().unfuse_lora(components=components)
1871+
super().unfuse_lora(components=components, **kwargs)
18721872

18731873
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
18741874
def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -2570,7 +2570,7 @@ def fuse_lora(
25702570
```
25712571
"""
25722572
super().fuse_lora(
2573-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
2573+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
25742574
)
25752575

25762576
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2588,7 +2588,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
25882588
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
25892589
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
25902590
"""
2591-
super().unfuse_lora(components=components)
2591+
super().unfuse_lora(components=components, **kwargs)
25922592

25932593

25942594
class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2873,7 +2873,7 @@ def fuse_lora(
28732873
```
28742874
"""
28752875
super().fuse_lora(
2876-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
2876+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
28772877
)
28782878

28792879
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2891,7 +2891,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
28912891
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
28922892
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
28932893
"""
2894-
super().unfuse_lora(components=components)
2894+
super().unfuse_lora(components=components, **kwargs)
28952895

28962896

28972897
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3176,7 +3176,7 @@ def fuse_lora(
31763176
```
31773177
"""
31783178
super().fuse_lora(
3179-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3179+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
31803180
)
31813181

31823182
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -3194,7 +3194,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
31943194
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
31953195
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
31963196
"""
3197-
super().unfuse_lora(components=components)
3197+
super().unfuse_lora(components=components, **kwargs)
31983198

31993199

32003200
class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3479,7 +3479,7 @@ def fuse_lora(
34793479
```
34803480
"""
34813481
super().fuse_lora(
3482-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3482+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
34833483
)
34843484

34853485
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -3497,7 +3497,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
34973497
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
34983498
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
34993499
"""
3500-
super().unfuse_lora(components=components)
3500+
super().unfuse_lora(components=components, **kwargs)
35013501

35023502

35033503
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3785,7 +3785,7 @@ def fuse_lora(
37853785
```
37863786
"""
37873787
super().fuse_lora(
3788-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3788+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
37893789
)
37903790

37913791
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -3803,7 +3803,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
38033803
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
38043804
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
38053805
"""
3806-
super().unfuse_lora(components=components)
3806+
super().unfuse_lora(components=components, **kwargs)
38073807

38083808

38093809
class Lumina2LoraLoaderMixin(LoraBaseMixin):
@@ -4093,7 +4093,7 @@ def fuse_lora(
40934093
```
40944094
"""
40954095
super().fuse_lora(
4096-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4096+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
40974097
)
40984098

40994099
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
@@ -4112,7 +4112,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
41124112
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
41134113
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
41144114
"""
4115-
super().unfuse_lora(components=components)
4115+
super().unfuse_lora(components=components, **kwargs)
41164116

41174117

41184118
class WanLoraLoaderMixin(LoraBaseMixin):
@@ -4398,7 +4398,7 @@ def fuse_lora(
43984398
```
43994399
"""
44004400
super().fuse_lora(
4401-
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4401+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names, **kwargs
44024402
)
44034403

44044404
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4417,7 +4417,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
44174417
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
44184418
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
44194419
"""
4420-
super().unfuse_lora(components=components)
4420+
super().unfuse_lora(components=components, **kwargs)
44214421

44224422

44234423
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):

0 commit comments

Comments (0)