diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index b147eb446..a465c7c6a 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -4,6 +4,7 @@
 import os
 import time
 import concurrent.futures
+import re
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
@@ -110,15 +111,56 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
             logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
 
         if method == "LoRA":
+            def convert_diffusers_labels_to_unet(name: str) -> str:
+                if "_attentions_" not in name:
+                    return name  # only attention module names need remapping
+
+                # Normalize diffusers stage tokens (up/down/mid) to unet-style output/input/middle
+                s = (name.replace("unet_up", "unet_output")
+                     .replace("unet_down", "unet_input")
+                     .replace("unet_mid", "unet_middle"))
+
+                # Middle: ...middle_block_attentions_X_* -> ...middle_block_{X+1}_*
+                if "unet_middle" in s and "middle_block_attentions_" in s:
+                    return re.sub(
+                        r"middle_block_attentions_(\d+)_",
+                        lambda m: f"middle_block_{int(m.group(1)) + 1}_",
+                        s,
+                    )
+
+                left, right = s.split("_attentions_", 1)
+                L = left.split("_")  # lora_unet_[input|output]_blocks_{X}
+                stage = L[2]
+                X = int(L[-1])  # down/up block index in diffusers naming
+                Y_str, *rest = right.split("_")
+                Y = int(Y_str)  # attentions index in diffusers naming
+                # Remainder includes transformer/proj tail (kept as-is)
+                # Map to unet input/output_blocks indices:
+                if stage == "input":  # from "down"
+                    if X == 1: i = 4 + Y  # → input_blocks_{4|5}_1
+                    elif X == 2: i = 7 + Y  # → input_blocks_{7|8}_1
+                    else: return s  # no attentions elsewhere
+                    j = 1
+                elif stage == "output":  # from "up"
+                    i = 3 * X + Y  # X=0→0..2, X=1→3..5, X=2→6..8
+                    j = 1
+                else:
+                    return s
+
+                L[-1] = str(i)
+                return "_".join(L + [str(j)] + rest)
+
             for key in tqdm(lora_sd.keys()):
                 if "lora_down" in key:
-                    up_key = key.replace("lora_down", "lora_up")
-                    alpha_key = key[: key.index("lora_down")] + "alpha"
+                    key_base = key[: key.index("lora_down")]
+                    up_key = key_base + "lora_up.weight"
+                    dora_key = key_base + "dora_scale"
+                    alpha_key = key_base + "alpha"
 
                     # find original module for this lora
-                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                    module_name = ".".join(convert_diffusers_labels_to_unet(key).split(".")[:-2])  # remove trailing ".lora_down.weight"
                     if module_name not in name_to_module:
-                        logger.info(f"no module found for LoRA weight: {key}")
+                        logger.info(f"no module found for LoRA weight: {module_name}, from({key})")
                         continue
                     module = name_to_module[module_name]
                     # logger.info(f"apply {key} to {module}")
@@ -127,7 +169,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
                     up_weight = lora_sd[up_key]
 
                     dim = down_weight.size()[0]
-                    alpha = lora_sd.get(alpha_key, dim)
+                    alpha = lora_sd.get(alpha_key, 1.0)
                     scale = alpha / dim
 
                     if lbw:
@@ -138,23 +180,62 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
 
                     # W <- W + U * D
                     weight = module.weight
+                    lora_diff = None
                     # logger.info(module_name, down_weight.size(), up_weight.size())
                     if len(weight.size()) == 2:
                         # linear
-                        weight = weight + ratio * (up_weight @ down_weight) * scale
+                        lora_diff = (up_weight @ down_weight)
                     elif down_weight.size()[2:4] == (1, 1):
                         # conv2d 1x1
-                        weight = (
-                            weight
-                            + ratio
-                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                            * scale
-                        )
+                        lora_diff = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                     else:
                         # conv2d 3x3
-                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                        lora_diff = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                         # logger.info(conved.size(), weight.size(), module.stride, module.padding)
-                        weight = weight + ratio * conved * scale
+
+                    dora_scale = lora_sd.get(dora_key, None)
+
+                    # Algorithm/math taken from reForge
+                    if dora_scale is None:
+                        # -------- Plain LoRA (mirrors the original merge math) --------
+                        # W <- W + ratio * (lora_diff * scale)
+                        weight = weight + (ratio * (lora_diff * scale)).to(dtype=weight.dtype, device=weight.device)
+
+                    else:
+                        # -------- DoRA (literal reForge semantics) --------
+                        # cast dora_scale as reForge does (to the intermediate dtype; ops below use weight.dtype)
+                        ds = dora_scale.to(device=weight.device, dtype=merge_dtype)
+
+                        # lora_diff gets 'alpha' (scale == alpha/rank) BEFORE the magnitude step; strength is applied AFTER it
+                        lora_diff_scaled = (lora_diff * scale).to(dtype=weight.dtype, device=weight.device)
+
+                        # weight_calc = weight + function(lora_diff_scaled); function is identity here
+                        weight_calc = weight + lora_diff_scaled
+
+                        wd_on_output_axis = (ds.shape[0] == weight_calc.shape[0])
+                        if wd_on_output_axis:
+                            # per-OUT norm taken from the ORIGINAL weight (matches reForge)
+                            weight_norm = (
+                                weight.reshape(weight.shape[0], -1)
+                                .norm(dim=1, keepdim=True)
+                                .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
+                            )
+                        else:
+                            # per-IN norm from weight_calc^T (matches reForge)
+                            wc = weight_calc.transpose(0, 1)
+                            weight_norm = (
+                                wc.reshape(wc.shape[0], -1)
+                                .norm(dim=1, keepdim=True)
+                                .reshape(wc.shape[0], *[1] * (wc.dim() - 1))
+                                .transpose(0, 1)
+                            )
+
+                        weight_norm = weight_norm + torch.finfo(weight.dtype).eps
+
+                        # Apply magnitude: weight_calc *= (dora_scale / weight_norm)
+                        # (Do NOT reshape ds; rely on its stored shape for broadcasting)
+                        weight_calc = weight_calc * (ds.to(dtype=weight.dtype) / weight_norm)
+                        weight = torch.lerp(weight, weight_calc, float(ratio))
 
                     module.weight = torch.nn.Parameter(weight)
 
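Not part of the patch: a few illustrative key remappings that the new convert_diffusers_labels_to_unet helper is expected to perform for SDXL attention modules, derived from the index arithmetic in the hunk above (the trailing ".lora_down.weight" suffix is omitted for brevity). Useful for spot-checking before a merge.

    # Illustrative only: diffusers-style LoRA keys (left) and the unet/ldm-style
    # module names (right) that the remapping above should produce for SDXL.
    expected = {
        "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q":
            "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q",
        "lora_unet_down_blocks_2_attentions_1_transformer_blocks_0_attn2_to_k":
            "lora_unet_input_blocks_8_1_transformer_blocks_0_attn2_to_k",
        "lora_unet_mid_block_attentions_0_transformer_blocks_3_ff_net_0_proj":
            "lora_unet_middle_block_1_transformer_blocks_3_ff_net_0_proj",
        "lora_unet_up_blocks_1_attentions_2_transformer_blocks_1_attn1_to_v":
            "lora_unet_output_blocks_5_1_transformer_blocks_1_attn1_to_v",
    }
    for diffusers_name, unet_name in expected.items():
        print(f"{diffusers_name}\n  -> {unet_name}")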
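Also not part of the patch: a minimal standalone sketch of the DoRA branch's order of operations for a 2-D Linear weight, using made-up tensors and assuming dora_scale is stored per output channel (so the norm is taken from the original weight, as in the output-axis case above).

    import torch

    torch.manual_seed(0)
    weight = torch.randn(8, 16)            # original Linear weight W (out_features, in_features)
    up_weight = torch.randn(8, 4)          # lora_up
    down_weight = torch.randn(4, 16)       # lora_down
    dora_scale = torch.randn(8, 1).abs()   # stored per-output-channel magnitude
    alpha, ratio = 4.0, 0.8

    dim = down_weight.size()[0]            # rank
    scale = alpha / dim
    lora_diff = up_weight @ down_weight    # low-rank delta, same shape as W

    # direction update: alpha/rank is folded in before the magnitude step
    weight_calc = weight + lora_diff * scale

    # dora_scale lies on the output axis here, so the norm comes from the ORIGINAL weight
    weight_norm = weight.reshape(weight.shape[0], -1).norm(dim=1, keepdim=True)
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc = weight_calc * (dora_scale / weight_norm)  # re-apply the stored magnitude
    merged = torch.lerp(weight, weight_calc, ratio)          # merge strength applied last

    print(merged.shape)  # torch.Size([8, 16])

The final lerp is what applies the merge ratio, so the ratio scales the whole DoRA update (direction plus magnitude correction) rather than only the low-rank delta.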