@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict


-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)

-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+    # Check for DoRA-enabled LoRAs.
+    if any(
+        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
+        for k in state_dict
+    ):
         if is_peft_version("<", "0.9.0"):
             raise ValueError(
                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
             )

-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"

+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+            diffusers_name = _convert_unet_lora_key(key)

-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-            diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
-            # SDXL specificity.
-            if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
-                pattern = r"\.\d+(?=\D*$)"
-                diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
-            if ".in." in diffusers_name:
-                diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
-            if ".out." in diffusers_name:
-                diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
-            if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
-                diffusers_name = diffusers_name.replace("op", "conv")
-            if "skip" in diffusers_name:
-                diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
-            # LyCORIS specificity.
-            if "time.emb.proj" in diffusers_name:
-                diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
-            if "conv.shortcut" in diffusers_name:
-                diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
-            # General coverage.
-            if "transformer_blocks" in diffusers_name:
-                if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                elif "ff" in diffusers_name:
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            else:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # Store down and up weights.
+            unet_state_dict[diffusers_name] = state_dict.pop(key)
+            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

-            if is_unet_dora_lora:
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                 unet_state_dict[
                     diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

+        # Handle text encoder LoRAs.
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
             if lora_name.startswith(("lora_te_", "lora_te1_")):
-                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+                te_state_dict[diffusers_name] = state_dict.pop(key)
+                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             else:
-                key_to_replace = "lora_te2_"
-
-            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
-
-            if "self_attn" in diffusers_name:
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            # OneTrainer specificity
-            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

-            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
                 dora_scale_key_to_replace_te = (
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
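For orientation, the loop rewritten above consumes Kohya-style checkpoints in which each module contributes a down weight, an up weight, and optionally a scalar alpha (DoRA checkpoints additionally carry a per-module `dora_scale` entry, which is what the consolidated `any(...)` check looks for). A minimal sketch of such an input, with hypothetical key names and illustrative shapes:

import torch

# Hypothetical Kohya-style LoRA entries (rank 4 on a 320-dim projection).
state_dict = {
    "lora_unet_mid_block_attentions_0_proj_in.lora_down.weight": torch.zeros(4, 320),
    "lora_unet_mid_block_attentions_0_proj_in.lora_up.weight": torch.zeros(320, 4),
    "lora_unet_mid_block_attentions_0_proj_in.alpha": torch.tensor(4.0),
}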
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                         diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                     ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

-        # Rename the alphas so that they can be mapped appropriately.
+        # Store alpha if present.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))

+    # Check if any keys remain.
     if len(state_dict) > 0:
         raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")

     logger.info("Kohya-style checkpoint detected.")
+
+    # Construct final state dict.
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
     te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
     te2_state_dict = (
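For illustration, this is what the `_get_alpha_name` helper introduced in the next hunk returns for the hypothetical U-Net key from the sketch above; the mapping mirrors the inlined prefix logic that this hunk removes:

# Hypothetical example; assumes _get_alpha_name (defined below) is in scope.
lora_name_alpha = "lora_unet_mid_block_attentions_0_proj_in.alpha"
diffusers_name = "mid_block.attentions.0.proj_in.lora.down.weight"
print(_get_alpha_name(lora_name_alpha, diffusers_name, 4.0))
# {'unet.mid_block.attentions.0.proj_in.alpha': 4.0}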
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_

     new_state_dict = {**unet_state_dict, **te_state_dict}
     return new_state_dict, network_alphas
+
+
+def _convert_unet_lora_key(key):
+    """
+    Converts a U-Net LoRA key to a Diffusers compatible key.
+    """
+    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+    # Replace common U-Net naming patterns.
+    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+    # SDXL specific conversions.
+    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+        pattern = r"\.\d+(?=\D*$)"
+        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+    if ".in." in diffusers_name:
+        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+    if ".out." in diffusers_name:
+        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+        diffusers_name = diffusers_name.replace("op", "conv")
+    if "skip" in diffusers_name:
+        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+    # LyCORIS specific conversions.
+    if "time.emb.proj" in diffusers_name:
+        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+    if "conv.shortcut" in diffusers_name:
+        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+    # General conversions.
+    if "transformer_blocks" in diffusers_name:
+        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+    elif "ff" in diffusers_name:
+        pass
+    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+        pass
+    else:
+        pass
+
+    return diffusers_name
+
+
+def _convert_text_encoder_lora_key(key, lora_name):
+    """
+    Converts a text encoder LoRA key to a Diffusers compatible key.
+    """
+    if lora_name.startswith(("lora_te_", "lora_te1_")):
+        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+    else:
+        key_to_replace = "lora_te2_"
+
+    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+    diffusers_name = diffusers_name.replace("text.model", "text_model")
+    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
+    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+        pass
+    elif "mlp" in diffusers_name:
+        # Be aware that this is the new diffusers convention and the rest of the code might
+        # not utilize it yet.
+        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+    return diffusers_name
+
+
+def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
+    """
+    Gets the correct alpha name for the Diffusers model.
+    """
+    if lora_name_alpha.startswith("lora_unet_"):
+        prefix = "unet."
+    elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+        prefix = "text_encoder."
+    else:
+        prefix = "text_encoder_2."
+    new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+    return {new_name: alpha}
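Taken together, a minimal end-to-end sketch of the refactored converter on the hypothetical input from the first sketch (run inside this module, so the private helpers and the module-level `logger` are in scope):

import torch

state_dict = {
    "lora_unet_mid_block_attentions_0_proj_in.lora_down.weight": torch.zeros(4, 320),
    "lora_unet_mid_block_attentions_0_proj_in.lora_up.weight": torch.zeros(320, 4),
    "lora_unet_mid_block_attentions_0_proj_in.alpha": torch.tensor(4.0),
}

converted, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
print(sorted(converted))
# ['unet.mid_block.attentions.0.proj_in.lora.down.weight',
#  'unet.mid_block.attentions.0.proj_in.lora.up.weight']
print(network_alphas)
# {'unet.mid_block.attentions.0.proj_in.alpha': 4.0}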