@@ -1263,17 +1263,13 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=None,
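Note on the hunk above: the caller-side filter that built transformer_state_dict and skipped the call when it was empty is removed, and load_lora_into_transformer is now invoked unconditionally with the full state_dict; presumably the "no transformer LoRA keys" case is handled inside that method. The transformer selection logic itself is unchanged. A minimal sketch of that selection pattern, using a hypothetical stand-in class and attribute names that are not part of the diff:

    class _PipelineSketch:
        # Hypothetical pipeline whose transformer attribute is not literally named "transformer".
        transformer_name = "transformer_2"
        transformer_2 = "actual transformer module"

    pipe = _PipelineSketch()
    # Same selection as the diff: prefer pipe.transformer when it exists,
    # otherwise fall back to the attribute named by pipe.transformer_name.
    target = pipe.transformer if hasattr(pipe, "transformer") else getattr(pipe, pipe.transformer_name)
    assert target == "actual transformer module"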
@@ -1809,12 +1805,12 @@ def load_lora_weights(
         transformer_lora_state_dict = {
             k: state_dict.get(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name) and "lora" in k
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
             for k in list(state_dict.keys())
-            if k.startswith(self.transformer_name)
+            if k.startswith(f"{self.transformer_name}.")
             and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
 
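Note on the hunk above: the prefix check gains a trailing dot, so keys are matched against f"{self.transformer_name}." rather than the bare component name. A plain startswith on the bare name would also match any sibling key whose name merely begins with that string. The sketch below illustrates the difference with hypothetical key names (the "transformer_2" entry is an assumption used for illustration, not taken from the diff):

    state_dict = {
        "transformer.blocks.0.attn.lora_A.weight": 1,    # belongs to the transformer
        "transformer_2.blocks.0.attn.lora_A.weight": 2,  # hypothetical sibling component
        "text_encoder.lora_A.weight": 3,
    }
    prefix = "transformer"

    # Old filter: over-matches the sibling component whose name shares the prefix.
    loose = [k for k in state_dict if k.startswith(prefix) and "lora" in k]
    # New filter: the trailing dot pins the match to the exact component name.
    strict = [k for k in state_dict if k.startswith(f"{prefix}.") and "lora" in k]

    assert "transformer_2.blocks.0.attn.lora_A.weight" in loose
    assert "transformer_2.blocks.0.attn.lora_A.weight" not in strict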