@@ -843,11 +843,11 @@ def save_lora_weights(
 
         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
             )
 
         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
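
For context, `pack_weights` (which the hunk above now calls with `cls.unet_name` instead of the hard-coded `"unet"` string) namespaces every LoRA key under the given prefix. A minimal sketch of that behavior, assuming the helper accepts either a module or a plain tensor dict:

```python
from typing import Dict, Union

import torch


def pack_weights(layers: Union[torch.nn.Module, Dict[str, torch.Tensor]], prefix: str) -> Dict[str, torch.Tensor]:
    """Sketch: namespace LoRA keys under `prefix`, e.g. "unet" -> "unet.<key>"."""
    weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in weights.items()}
```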
@@ -1210,10 +1210,11 @@ def load_lora_into_text_encoder(
         )
 
     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         is_main_process: bool = True,
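
The widened annotation matches the two text-encoder parameters: callers may pass raw tensors (e.g. a PEFT state dict) as well as modules. Both forms below satisfy the new `Dict[str, Union[torch.nn.Module, torch.Tensor]]` type (a toy illustration, not taken from the diff):

```python
import torch

# Raw tensors, e.g. as returned by peft's get_peft_model_state_dict:
as_tensors = {"proj.lora_A.weight": torch.zeros(4, 16)}
# Or modules, as some training loops collect them:
as_modules = {"proj.lora_A": torch.nn.Linear(16, 4, bias=False)}
```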
@@ -1262,7 +1263,6 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
-        # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
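
The dropped `# Save the model` comment was redundant: `write_lora_layers` is the serialization step itself. Roughly, and with argument handling simplified (the default filenames below are my assumption about diffusers' conventions):

```python
import os

import safetensors.torch
import torch


def write_lora_layers(state_dict, save_directory, weight_name=None, safe_serialization=True):
    """Simplified sketch of the final serialization step."""
    os.makedirs(save_directory, exist_ok=True)
    if weight_name is None:
        weight_name = "pytorch_lora_weights.safetensors" if safe_serialization else "pytorch_lora_weights.bin"
    path = os.path.join(save_directory, weight_name)
    if safe_serialization:
        safetensors.torch.save_file(state_dict, path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, path)
```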
@@ -1272,6 +1272,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
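
End to end, the `save_lora_weights` classmethod touched above is typically driven like this (a hypothetical sketch: the SD3 pipeline, checkpoint id, and output path are assumptions, and `get_peft_model_state_dict` presumes a LoRA adapter was previously attached to the transformer):

```python
import torch
from diffusers import StableDiffusion3Pipeline
from peft.utils import get_peft_model_state_dict

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
# ... attach and train a LoRA adapter on pipe.transformer ...
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
pipe.save_lora_weights(
    save_directory="./my-sd3-lora",
    transformer_lora_layers=transformer_lora_layers,
    safe_serialization=True,
)
```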
@@ -1315,6 +1316,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
         r"""
         Reverses the effect of
@@ -1328,7 +1330,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
 
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
             unfuse_text_encoder (`bool`, defaults to `True`):
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
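
A usage sketch for the fuse/unfuse pair documented above (checkpoint and LoRA paths are placeholders):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA location
# Fold the LoRA deltas into the base weights to avoid per-step adapter overhead.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
# ... run inference ...
pipe.unfuse_lora(components=["transformer"])  # restore the original weights
```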
@@ -2833,6 +2835,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3136,6 +3140,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3439,6 +3445,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3745,6 +3753,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
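
The `# Copied from ...` markers added throughout these hunks are the mechanically checked annotations diffusers uses (via `make fix-copies`) to keep duplicated methods in sync with a single source definition. A toy scanner for such markers, purely illustrative and not the repository's actual checker:

```python
import re

COPIED_FROM = re.compile(r"#\s*Copied from\s+(?P<target>[\w.]+)(?:\s+with\s+(?P<rule>.+))?")


def find_copy_markers(source: str):
    """List (line_no, copy target, optional rename rule) for every marker."""
    return [
        (i, m.group("target"), m.group("rule"))
        for i, line in enumerate(source.splitlines(), start=1)
        if (m := COPIED_FROM.search(line))
    ]


print(find_copy_markers(
    "    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora\n"
    "    def fuse_lora(self): ...\n"
))
```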