@@ -1807,7 +1807,7 @@ def load_lora_weights(
             raise ValueError("Invalid LoRA checkpoint.")

         transformer_lora_state_dict = {
-            k: state_dict.pop(k)
+            k: state_dict.get(k)
             for k in list(state_dict.keys())
             if k.startswith(self.transformer_name) and "lora" in k
         }
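The only change in this hunk is state_dict.pop(k) -> state_dict.get(k): pop removes each matched LoRA key from state_dict while building the transformer sub-dict, whereas get copies the value out and leaves the full state_dict intact for later use. A minimal sketch of that difference with plain dicts (the key names below are invented for illustration, not real checkpoint keys):

# Sketch only: plain dicts stand in for tensors; key names are made up.
state_dict = {
    "transformer.blocks.0.lora_A.weight": 1,
    "transformer.blocks.0.lora_B.weight": 2,
    "text_encoder.lora_A.weight": 3,
}

# Old behaviour: pop() strips matched keys out of state_dict as it goes.
popped = {
    k: state_dict.pop(k)
    for k in list(state_dict.keys())
    if k.startswith("transformer") and "lora" in k
}
# state_dict is now {"text_encoder.lora_A.weight": 3}

# New behaviour: get() builds the same sub-dict but leaves state_dict whole,
# so the complete dict can still be handed to the transformer loader later.
state_dict = {
    "transformer.blocks.0.lora_A.weight": 1,
    "transformer.blocks.0.lora_B.weight": 2,
    "text_encoder.lora_A.weight": 3,
}
fetched = {
    k: state_dict.get(k)
    for k in list(state_dict.keys())
    if k.startswith("transformer") and "lora" in k
}
# state_dict still holds all three keys.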
@@ -1819,29 +1819,33 @@ def load_lora_weights(
         }

         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
-            transformer, transformer_lora_state_dict, transformer_norm_state_dict
-        )
+        has_param_with_expanded_shape = False
+        if len(transformer_lora_state_dict) > 0:
+            has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+                transformer, transformer_lora_state_dict, transformer_norm_state_dict
+            )

         if has_param_with_expanded_shape:
             logger.info(
                 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
-            transformer=transformer, lora_state_dict=transformer_lora_state_dict
-        )
-
         if len(transformer_lora_state_dict) > 0:
-            self.load_lora_into_transformer(
-                transformer_lora_state_dict,
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+                transformer=transformer, lora_state_dict=transformer_lora_state_dict
             )
+            for k in transformer_lora_state_dict:
+                state_dict.update({k: transformer_lora_state_dict[k]})
+
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
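Taken together, the second hunk changes what ultimately reaches load_lora_into_transformer: shape expansion is only attempted when transformer LoRA keys are present, the (possibly expanded) entries are written back into the original state_dict via update(), and the full state_dict, rather than the transformer-only sub-dict, is passed to the loader unconditionally. A simplified, self-contained sketch of that flow, using plain dicts and a stand-in expand step (expand_value and load_lora_weights_sketch are hypothetical placeholders for illustration, not diffusers APIs):

# Sketch of the new control flow; plain dicts stand in for tensors and
# expand_value() is a placeholder for _maybe_expand_lora_state_dict.
def expand_value(v):
    return v  # pretend this pads/expands a LoRA weight when shapes differ

def load_lora_weights_sketch(state_dict, transformer_name="transformer"):
    # Split out transformer LoRA keys without removing them (get, not pop).
    transformer_lora_state_dict = {
        k: state_dict.get(k)
        for k in list(state_dict.keys())
        if k.startswith(transformer_name) and "lora" in k
    }

    # Only attempt expansion when there are transformer LoRA entries.
    if len(transformer_lora_state_dict) > 0:
        transformer_lora_state_dict = {
            k: expand_value(v) for k, v in transformer_lora_state_dict.items()
        }
        # Write the possibly-expanded entries back into the full state_dict.
        for k in transformer_lora_state_dict:
            state_dict.update({k: transformer_lora_state_dict[k]})

    # The loader now always receives the full state_dict, expanded or not.
    return state_dict

merged = load_lora_weights_sketch({"transformer.lora_A.weight": 1, "norm.scale": 2})
# merged still contains both keys; only the transformer LoRA entry went
# through the expand step.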