@@ -1863,6 +1863,9 @@ def load_lora_weights(
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+            transformer=transformer, lora_state_dict=transformer_lora_state_dict
+        )
 
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
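
For context, a hedged usage sketch of what this hunk enables (the pipeline class exists in diffusers, but the model IDs and the second LoRA repo are illustrative assumptions, not part of this diff): loading a Flux Control LoRA first widens the transformer's input layers, and a regular LoRA loaded afterwards is zero-padded by `_maybe_expand_lora_state_dict` to fit the widened shapes.

```python
import torch
from diffusers import FluxControlPipeline

# Model IDs below are assumptions for illustration, not part of this diff.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")  # widens e.g. x_embedder
pipe.load_lora_weights("some-user/regular-flux-lora")  # lora_A zero-padded to the widened in_features
```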
@@ -2309,16 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None
 
-                lora_A_weight_name = f"{name}.lora_A.weight"
-                lora_B_weight_name = f"{name}.lora_B.weight"
-                if lora_A_weight_name not in state_dict.keys():
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
                     continue
 
                 in_features = state_dict[lora_A_weight_name].shape[1]
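
A note on the `lora_base_name` normalization above: once a PEFT adapter is loaded, each wrapped `nn.Linear` exposes its original weights under a `.base_layer` submodule, so module names gain that suffix and must be stripped before building LoRA keys. A self-contained toy sketch (the module path is made up):

```python
# Toy illustration of the key normalization; not library code.
name = "transformer_blocks.0.attn.to_q.base_layer"  # module name after PEFT wraps the Linear
is_peft_loaded = True

lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
assert lora_A_weight_name == "transformer_blocks.0.attn.to_q.lora_A.weight"
```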
@@ -2329,56 +2333,105 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue
 
                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
                     )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} to {new_value} from {old_value}."
+                        )
 
-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+        return has_param_with_shape_update
 
-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file this."
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)
 
-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} to {new_value} from {old_value}.")
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}."
+            )
 
-        return has_param_with_shape_update
+        return lora_state_dict
 
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
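
A quick numerical sanity check of the zero-padding done in `_maybe_expand_lora_state_dict` (toy shapes, independent of this diff): adding zero columns to `lora_A` leaves the LoRA contribution unchanged on the original input features and inert on the expanded ones.

```python
import torch

# Toy shapes; pad lora_A's input dimension with zeros, as the new helper does.
rank, in_small, in_big = 4, 8, 12
lora_A = torch.randn(rank, in_small)
padded = torch.zeros(rank, in_big)
padded[:, :in_small].copy_(lora_A)

x = torch.randn(2, in_big)
# The padded projection ignores the extra input features entirely.
assert torch.allclose(x @ padded.T, x[:, :in_small] @ lora_A.T)
```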
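The expansion hunk also moves module creation to the meta device and fills it via `load_state_dict(..., assign=True)`, so no throwaway weight allocation happens. A minimal standalone sketch of that pattern (requires PyTorch 2.1+ for `assign=True`):

```python
import torch

# Create the Linear without allocating real storage, then assign real tensors in.
with torch.device("meta"):
    expanded = torch.nn.Linear(8, 4, bias=True)

state = {"weight": torch.zeros(4, 8), "bias": torch.zeros(4)}
expanded.load_state_dict(state, strict=True, assign=True)
print(expanded.weight.device)  # cpu: assign=True replaced the meta tensors
```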