@@ -1807,7 +1807,7 @@ def load_lora_weights(
18071807            raise  ValueError ("Invalid LoRA checkpoint." )
18081808
18091809        transformer_lora_state_dict  =  {
1810-             k : state_dict .pop (k )
1810+             k : state_dict .get (k )
18111811            for  k  in  list (state_dict .keys ())
18121812            if  k .startswith (self .transformer_name ) and  "lora"  in  k 
18131813        }
@@ -1819,29 +1819,33 @@ def load_lora_weights(
18191819        }
18201820
18211821        transformer  =  getattr (self , self .transformer_name ) if  not  hasattr (self , "transformer" ) else  self .transformer 
1822-         has_param_with_expanded_shape  =  self ._maybe_expand_transformer_param_shape_or_error_ (
1823-             transformer , transformer_lora_state_dict , transformer_norm_state_dict 
1824-         )
1822+         has_param_with_expanded_shape  =  False 
1823+         if  len (transformer_lora_state_dict ) >  0 :
1824+             has_param_with_expanded_shape  =  self ._maybe_expand_transformer_param_shape_or_error_ (
1825+                 transformer , transformer_lora_state_dict , transformer_norm_state_dict 
1826+             )
18251827
18261828        if  has_param_with_expanded_shape :
18271829            logger .info (
18281830                "The LoRA weights contain parameters that have different shapes that expected by the transformer. " 
18291831                "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " 
18301832                "To get a comprehensive list of parameter names that were modified, enable debug logging." 
18311833            )
1832-         transformer_lora_state_dict  =  self ._maybe_expand_lora_state_dict (
1833-             transformer = transformer , lora_state_dict = transformer_lora_state_dict 
1834-         )
1835- 
18361834        if  len (transformer_lora_state_dict ) >  0 :
1837-             self .load_lora_into_transformer (
1838-                 transformer_lora_state_dict ,
1839-                 network_alphas = network_alphas ,
1840-                 transformer = transformer ,
1841-                 adapter_name = adapter_name ,
1842-                 _pipeline = self ,
1843-                 low_cpu_mem_usage = low_cpu_mem_usage ,
1835+             transformer_lora_state_dict  =  self ._maybe_expand_lora_state_dict (
1836+                 transformer = transformer , lora_state_dict = transformer_lora_state_dict 
18441837            )
1838+             for  k  in  transformer_lora_state_dict :
1839+                 state_dict .update ({k : transformer_lora_state_dict [k ]})
1840+ 
1841+         self .load_lora_into_transformer (
1842+             state_dict ,
1843+             network_alphas = network_alphas ,
1844+             transformer = transformer ,
1845+             adapter_name = adapter_name ,
1846+             _pipeline = self ,
1847+             low_cpu_mem_usage = low_cpu_mem_usage ,
1848+         )
18451849
18461850        if  len (transformer_norm_state_dict ) >  0 :
18471851            transformer ._transformer_norm_layers  =  self ._load_norm_into_transformer (
0 commit comments