@@ -153,12 +153,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
153153                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
154154            )
155155
156-     # Iterate over all LoRA weights. 
157-     all_lora_keys  =  list (state_dict .keys ())
158-     for  key  in  all_lora_keys :
159-         if  not  key .endswith ("lora_down.weight" ):
160-             continue 
161- 
156+     # every down weight has a corresponding up weight and potentially an alpha weight 
157+     lora_keys  =  [k  for  k  in  state_dict .keys () if  k .endswith ("lora_down.weight" )]
158+     for  key  in  lora_keys :
162159        # Extract LoRA name. 
163160        lora_name  =  key .split ("." )[0 ]
164161
@@ -177,9 +174,12 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
177174            # Store DoRA scale if present. 
178175            if  dora_present_in_unet :
179176                dora_scale_key_to_replace  =  "_lora.down."  if  "_lora.down."  in  diffusers_name  else  ".lora.down." 
180-                 unet_state_dict [
181-                     diffusers_name .replace (dora_scale_key_to_replace , ".lora_magnitude_vector." )
182-                 ] =  state_dict .pop (key .replace ("lora_down.weight" , "dora_scale" ))
177+                 new_key  =  diffusers_name .replace (dora_scale_key_to_replace , ".lora_magnitude_vector." )
178+                 # dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) 
179+                 dora_weight  =  state_dict .pop (lora_name  +  ".dora_scale" )
180+                 if  dora_weight .dim () <=  2 :
181+                     dora_weight  =  dora_weight .squeeze ()
182+                 unet_state_dict [new_key ] =  dora_weight 
183183
184184        # Handle text encoder LoRAs. 
185185        elif  lora_name .startswith (("lora_te_" , "lora_te1_" , "lora_te2_" )):
@@ -194,18 +194,24 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
194194                te2_state_dict [diffusers_name .replace (".down." , ".up." )] =  state_dict .pop (lora_name_up )
195195
196196            # Store DoRA scale if present. 
197-             if  dora_present_in_te  or  dora_present_in_te2 :
197+             if  ( dora_present_in_te  or  dora_present_in_te2 ) :
198198                dora_scale_key_to_replace_te  =  (
199199                    "_lora.down."  if  "_lora.down."  in  diffusers_name  else  ".lora_linear_layer." 
200200                )
201201                if  lora_name .startswith (("lora_te_" , "lora_te1_" )):
202-                     te_state_dict [
203-                         diffusers_name .replace (dora_scale_key_to_replace_te , ".lora_magnitude_vector." )
204-                     ] =  state_dict .pop (key .replace ("lora_down.weight" , "dora_scale" ))
202+                     new_key  =  diffusers_name .replace (dora_scale_key_to_replace_te , ".lora_magnitude_vector." )
203+                     # dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) 
204+                     dora_weight  =  state_dict .pop (lora_name  +  ".dora_scale" )
205+                     if  dora_weight .dim () <=  2 :
206+                         dora_weight  =  dora_weight .squeeze ()
207+                     te_state_dict [new_key ] =  dora_weight 
205208                elif  lora_name .startswith ("lora_te2_" ):
206-                     te2_state_dict [
207-                         diffusers_name .replace (dora_scale_key_to_replace_te , ".lora_magnitude_vector." )
208-                     ] =  state_dict .pop (key .replace ("lora_down.weight" , "dora_scale" ))
209+                     new_key  =  diffusers_name .replace (dora_scale_key_to_replace_te , ".lora_magnitude_vector." )
210+                     # dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) 
211+                     dora_weight  =  state_dict .pop (lora_name  +  ".dora_scale" )
212+                     if  dora_weight .dim () <=  2 :
213+                         dora_weight  =  dora_weight .squeeze ()
214+                     te2_state_dict [new_key ] =  dora_weight 
209215
210216        # Store alpha if present. 
211217        if  lora_name_alpha  in  state_dict :
@@ -214,7 +220,8 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
214220
215221    # Check if any keys remain. 
216222    if  len (state_dict ) >  0 :
217-         raise  ValueError (f"The following keys have not been correctly renamed: \n \n  { ', ' .join (state_dict .keys ())} " )
223+         all_keys_remaining  =  sorted (list (state_dict .keys ()))
224+         raise  ValueError (f"The following keys have not been correctly renamed: \n \n  { ', ' .join (all_keys_remaining )} " )
218225
219226    logger .info ("Non-diffusers checkpoint detected." )
220227
@@ -285,7 +292,7 @@ def _convert_unet_lora_key(key):
285292        pass 
286293    else :
287294        pass 
288- 
295+      
289296    return  diffusers_name 
290297
291298
0 commit comments