@@ -1326,14 +1326,17 @@ def load_lora_into_transformer(
13261326 )
13271327
13281328 # Load the layers corresponding to transformer.
1329- logger .info (f"Loading { cls .transformer_name } ." )
1330- transformer .load_lora_adapter (
1331- state_dict ,
1332- network_alphas = None ,
1333- adapter_name = adapter_name ,
1334- _pipeline = _pipeline ,
1335- low_cpu_mem_usage = low_cpu_mem_usage ,
1336- )
1329+ keys = list (state_dict .keys ())
1330+ only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
1331+ if not only_text_encoder :
1332+ logger .info (f"Loading { cls .transformer_name } ." )
1333+ transformer .load_lora_adapter (
1334+ state_dict ,
1335+ network_alphas = None ,
1336+ adapter_name = adapter_name ,
1337+ _pipeline = _pipeline ,
1338+ low_cpu_mem_usage = low_cpu_mem_usage ,
1339+ )
13371340
13381341 @classmethod
13391342 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1845,14 +1848,17 @@ def load_lora_into_transformer(
18451848 )
18461849
18471850 # Load the layers corresponding to transformer.
1848- logger .info (f"Loading { cls .transformer_name } ." )
1849- transformer .load_lora_adapter (
1850- state_dict ,
1851- network_alphas = network_alphas ,
1852- adapter_name = adapter_name ,
1853- _pipeline = _pipeline ,
1854- low_cpu_mem_usage = low_cpu_mem_usage ,
1855- )
1851+ keys = list (state_dict .keys ())
1852+ only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
1853+ if not only_text_encoder :
1854+ logger .info (f"Loading { cls .transformer_name } ." )
1855+ transformer .load_lora_adapter (
1856+ state_dict ,
1857+ network_alphas = network_alphas ,
1858+ adapter_name = adapter_name ,
1859+ _pipeline = _pipeline ,
1860+ low_cpu_mem_usage = low_cpu_mem_usage ,
1861+ )
18561862
18571863 @classmethod
18581864 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2118,7 +2124,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
21182124 text_encoder_name = TEXT_ENCODER_NAME
21192125
21202126 @classmethod
2121- # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
2127+ # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
21222128 def load_lora_into_transformer (
21232129 cls , state_dict , network_alphas , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
21242130 ):
@@ -2149,14 +2155,17 @@ def load_lora_into_transformer(
21492155 )
21502156
21512157 # Load the layers corresponding to transformer.
2152- logger .info (f"Loading { cls .transformer_name } ." )
2153- transformer .load_lora_adapter (
2154- state_dict ,
2155- network_alphas = network_alphas ,
2156- adapter_name = adapter_name ,
2157- _pipeline = _pipeline ,
2158- low_cpu_mem_usage = low_cpu_mem_usage ,
2159- )
2158+ keys = list (state_dict .keys ())
2159+ only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
2160+ if not only_text_encoder :
2161+ logger .info (f"Loading { cls .transformer_name } ." )
2162+ transformer .load_lora_adapter (
2163+ state_dict ,
2164+ network_alphas = network_alphas ,
2165+ adapter_name = adapter_name ,
2166+ _pipeline = _pipeline ,
2167+ low_cpu_mem_usage = low_cpu_mem_usage ,
2168+ )
21602169
21612170 @classmethod
21622171 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2537,14 +2546,17 @@ def load_lora_into_transformer(
25372546 )
25382547
25392548 # Load the layers corresponding to transformer.
2540- logger .info (f"Loading { cls .transformer_name } ." )
2541- transformer .load_lora_adapter (
2542- state_dict ,
2543- network_alphas = None ,
2544- adapter_name = adapter_name ,
2545- _pipeline = _pipeline ,
2546- low_cpu_mem_usage = low_cpu_mem_usage ,
2547- )
2549+ keys = list (state_dict .keys ())
2550+ only_text_encoder = all (key .startswith (cls .text_encoder_name ) for key in keys )
2551+ if not only_text_encoder :
2552+ logger .info (f"Loading { cls .transformer_name } ." )
2553+ transformer .load_lora_adapter (
2554+ state_dict ,
2555+ network_alphas = None ,
2556+ adapter_name = adapter_name ,
2557+ _pipeline = _pipeline ,
2558+ low_cpu_mem_usage = low_cpu_mem_usage ,
2559+ )
25482560
25492561 @classmethod
25502562 # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
0 commit comments