@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
142
142
network_alphas = {}
143
143
144
144
# Check for DoRA-enabled LoRAs.
145
- if any (
146
- "dora_scale" in k and ("lora_unet_" in k or " lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k )
147
- for k in state_dict
148
- ) :
145
+ dora_present_in_unet = any ("dora_scale" in k and "lora_unet_" in k for k in state_dict )
146
+ dora_present_in_te = any ( "dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k ) for k in state_dict )
147
+ dora_present_in_te2 = any ( "dora_scale" in k and "lora_te2_" in k for k in state_dict )
148
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2 :
149
149
if is_peft_version ("<" , "0.9.0" ):
150
150
raise ValueError (
151
151
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
173
173
unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
174
174
175
175
# Store DoRA scale if present.
176
- if "dora_scale" in state_dict :
176
+ if dora_present_in_unet :
177
177
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
178
178
unet_state_dict [
179
179
diffusers_name .replace (dora_scale_key_to_replace , ".lora_magnitude_vector." )
@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
192
192
te2_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
193
193
194
194
# Store DoRA scale if present.
195
- if "dora_scale" in state_dict :
195
+ if dora_present_in_te or dora_present_in_te2 :
196
196
dora_scale_key_to_replace_te = (
197
197
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
198
198
)
@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
214
214
if len (state_dict ) > 0 :
215
215
raise ValueError (f"The following keys have not been correctly renamed: \n \n { ', ' .join (state_dict .keys ())} " )
216
216
217
- logger .info ("Kohya-style checkpoint detected." )
217
+ logger .info ("Non-diffusers checkpoint detected." )
218
218
219
219
# Construct final state dict.
220
220
unet_state_dict = {f"{ unet_name } .{ module_name } " : params for module_name , params in unet_state_dict .items ()}
0 commit comments