Commit a9c403c

[LoRA] refactor lora conversion utility. (#8295)
* refactor lora conversion utility.
* remove error raises.
* add onetrainer support too.
1 parent e7b9a07 commit a9c403c

File tree: 2 files changed, +146 -111 lines changed

src/diffusers/loaders/lora.py
src/diffusers/loaders/lora_conversion_utils.py
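
For orientation before the diff: the renamed entry point maps Kohya-style keys to the Diffusers layout. The sketch below is illustrative and not part of the commit; it calls the private helper directly on a toy single-layer state dict with arbitrary shapes.

    import torch

    from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers

    # Toy Kohya-style LoRA entry for a ResNet conv layer; shapes are arbitrary.
    state_dict = {
        "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight": torch.zeros(4, 320, 1, 1),
        "lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight": torch.zeros(320, 4, 1, 1),
        "lora_unet_down_blocks_0_resnets_0_conv1.alpha": torch.tensor(4.0),
    }

    new_state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
    # new_state_dict keys: "unet.down_blocks.0.resnets.0.conv1.lora.down.weight"
    #                      "unet.down_blocks.0.resnets.0.conv1.lora.up.weight"
    # network_alphas:      {"unet.down_blocks.0.resnets.0.conv1.alpha": 4.0}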

src/diffusers/loaders/lora.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
 
 
 if is_transformers_available():

@@ -288,7 +288,7 @@ def lora_state_dict(
             if unet_config is not None:
                 # use unet config to remap block numbers
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-            state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
         return state_dict, network_alphas
 
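
The public loading path is unchanged by the rename; only the internal helper it delegates to has a new name. A sketch of the typical call (the checkpoint path is hypothetical; Kohya-, LyCORIS-, and OneTrainer-style files are converted on the fly):

    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights("path/to/kohya_style_lora.safetensors")  # hypothetical file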

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 144 additions & 109 deletions
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict
 
 
-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
 
-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+    # Check for DoRA-enabled LoRAs.
+    if any(
+        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
+        for k in state_dict
+    ):
         if is_peft_version("<", "0.9.0"):
             raise ValueError(
                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
             )
 
-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"
 
+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+            diffusers_name = _convert_unet_lora_key(key)
 
-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-            diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
-            # SDXL specificity.
-            if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
-                pattern = r"\.\d+(?=\D*$)"
-                diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
-            if ".in." in diffusers_name:
-                diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
-            if ".out." in diffusers_name:
-                diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
-            if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
-                diffusers_name = diffusers_name.replace("op", "conv")
-            if "skip" in diffusers_name:
-                diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
-            # LyCORIS specificity.
-            if "time.emb.proj" in diffusers_name:
-                diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
-            if "conv.shortcut" in diffusers_name:
-                diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
-            # General coverage.
-            if "transformer_blocks" in diffusers_name:
-                if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                elif "ff" in diffusers_name:
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            else:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # Store down and up weights.
+            unet_state_dict[diffusers_name] = state_dict.pop(key)
+            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-            if is_unet_dora_lora:
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                 unet_state_dict[
                     diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
+        # Handle text encoder LoRAs.
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
             if lora_name.startswith(("lora_te_", "lora_te1_")):
-                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+                te_state_dict[diffusers_name] = state_dict.pop(key)
+                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             else:
-                key_to_replace = "lora_te2_"
-
-            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
-
-            if "self_attn" in diffusers_name:
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            # OneTrainer specificity
-            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
                 dora_scale_key_to_replace_te = (
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )

@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                     diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
-        # Rename the alphas so that they can be mapped appropriately.
+        # Store alpha if present.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
 
+    # Check if any keys remain.
     if len(state_dict) > 0:
         raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
 
     logger.info("Kohya-style checkpoint detected.")
+
+    # Construct final state dict.
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
     te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
     te2_state_dict = (
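
The three per-module DoRA flags of the old code collapse into the single any(...) predicate in the first hunk above. A standalone sketch of what it evaluates, using a made-up key name:

    state_dict = {"lora_unet_down_blocks_0_attentions_0_proj_in.dora_scale": None}

    is_dora = any(
        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
        for k in state_dict
    )
    print(is_dora)  # True -> the converter requires peft >= 0.9.0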
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
 
     new_state_dict = {**unet_state_dict, **te_state_dict}
     return new_state_dict, network_alphas
+
+
+def _convert_unet_lora_key(key):
+    """
+    Converts a U-Net LoRA key to a Diffusers compatible key.
+    """
+    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+    # Replace common U-Net naming patterns.
+    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+    # SDXL specific conversions.
+    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+        pattern = r"\.\d+(?=\D*$)"
+        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+    if ".in." in diffusers_name:
+        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+    if ".out." in diffusers_name:
+        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+        diffusers_name = diffusers_name.replace("op", "conv")
+    if "skip" in diffusers_name:
+        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+    # LyCORIS specific conversions.
+    if "time.emb.proj" in diffusers_name:
+        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+    if "conv.shortcut" in diffusers_name:
+        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+    # General conversions.
+    if "transformer_blocks" in diffusers_name:
+        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+        elif "ff" in diffusers_name:
+            pass
+    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+        pass
+    else:
+        pass
+
+    return diffusers_name
+
+
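An illustrative trace of _convert_unet_lora_key on a typical attention key (the key is made up; each stage applies the replacements above in order):

    key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
    # 1. Strip "lora_unet_" and turn "_" into ".":
    #    "down.blocks.0.attentions.0.transformer.blocks.0.attn1.to.q.lora.down.weight"
    # 2. Common patterns ("down.blocks" -> "down_blocks", "transformer.blocks" ->
    #    "transformer_blocks", "to.q.lora" -> "to_q_lora"):
    #    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_lora.down.weight"
    # 3. General conversions ("attn1" -> "attn1.processor"):
    print(_convert_unet_lora_key(key))
    # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
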
+def _convert_text_encoder_lora_key(key, lora_name):
+    """
+    Converts a text encoder LoRA key to a Diffusers compatible key.
+    """
+    if lora_name.startswith(("lora_te_", "lora_te1_")):
+        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+    else:
+        key_to_replace = "lora_te2_"
+
+    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+    diffusers_name = diffusers_name.replace("text.model", "text_model")
+    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
+    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+        pass
+    elif "mlp" in diffusers_name:
+        # Be aware that this is the new diffusers convention and the rest of the code might
+        # not utilize it yet.
+        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+    return diffusers_name
+
+
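An illustrative trace of _convert_text_encoder_lora_key on two made-up keys, one self-attention and one MLP (the MLP branch switches to the lora_linear_layer naming):

    attn_key = "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
    print(_convert_text_encoder_lora_key(attn_key, attn_key.split(".")[0]))
    # "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"

    mlp_key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    print(_convert_text_encoder_lora_key(mlp_key, mlp_key.split(".")[0]))
    # "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight"
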
+def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
+    """
+    Gets the correct alpha name for the Diffusers model.
+    """
+    if lora_name_alpha.startswith("lora_unet_"):
+        prefix = "unet."
+    elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+        prefix = "text_encoder."
+    else:
+        prefix = "text_encoder_2."
+    new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+    return {new_name: alpha}
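An illustrative call of _get_alpha_name, matching the alpha entry from the toy conversion sketched near the top of this page:

    alphas = _get_alpha_name(
        "lora_unet_down_blocks_0_resnets_0_conv1.alpha",    # original alpha key
        "down_blocks.0.resnets.0.conv1.lora.down.weight",   # converted down-weight key
        4.0,
    )
    print(alphas)  # {"unet.down_blocks.0.resnets.0.conv1.alpha": 4.0}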
