diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 791b7ae9b14f..3404f6d91569 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1608,3 +1608,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index a29b77acce6e..2e241bc9ffad 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -42,6 +42,7 @@
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
+    _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
@@ -4794,6 +4795,8 @@ def lora_state_dict(
             )
         if any(k.startswith("diffusion_model.") for k in state_dict):
             state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)

         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
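
Note on the conversion: `get_alpha_scales` splits the effective LoRA scale `alpha / rank` between the two factors in powers of two, so that `scale_down * scale_up == alpha / rank` while neither rescaled matrix ends up with an extreme magnitude; the converted `lora_A` / `lora_B` weights then need no separate alpha entry.

Usage sketch (not part of the diff): with this change, a Wan LoRA trained with musubi-tuner should load through the standard `load_lora_weights` path, since its `lora_unet_`-prefixed keys are now detected and converted. The checkpoint ID and LoRA file path below are placeholders, not values taken from this PR.

import torch
from diffusers import WanPipeline

# Placeholder checkpoint; any Wan 2.1 checkpoint in Diffusers format should work.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)

# Keys starting with "lora_unet_" are routed through
# _convert_musubi_wan_lora_to_diffusers before the adapter weights are loaded.
pipe.load_lora_weights("path/to/musubi_wan_lora.safetensors")  # placeholder path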