
Commit 39d4e0c

Merge branch 'main' into post-release-0.33.0
2 parents: d81efdc + ea5a6a8

File tree: 12 files changed (+326, -285 lines)


scripts/convert_vae_pt_to_diffusers.py

Lines changed: 18 additions & 2 deletions
@@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     }
 
     for i in range(num_down_blocks):
-        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+        resnets = [
+            key
+            for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
+        ]
+        attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
 
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
     mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):
@@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i
         resnets = [
-            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+            key
+            for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
         ]
+        attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
 
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
 
+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
     mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):
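
Why the added "attn" filter matters: LDM-style VAE checkpoints that use attention inside their encoder/decoder levels store those weights under down.{i}.attn.* and up.{i}.attn.*, and the old comprehension would have swept them into the resnet group. A minimal, self-contained sketch of the split; the key names are hypothetical examples in the LDM layout, not taken from a real checkpoint:

# Hypothetical LDM-style VAE keys for encoder level 2.
keys = [
    "encoder.down.2.block.0.conv1.weight",    # resnet conv
    "encoder.down.2.block.0.norm1.weight",    # resnet norm
    "encoder.down.2.attn.0.k.weight",         # attention projection
    "encoder.down.2.downsample.conv.weight",  # downsampler
]

i = 2
# Old filter: the attention key would land in `resnets` by mistake.
resnets_old = [k for k in keys if f"down.{i}" in k and f"down.{i}.downsample" not in k]
# New filters, mirroring the diff: resnets and attentions are kept apart.
resnets = [k for k in keys if f"down.{i}" in k and f"down.{i}.downsample" not in k and "attn" not in k]
attentions = [k for k in keys if f"down.{i}.attn" in k]

print(resnets_old)  # three keys, including ...attn.0.k.weight
print(resnets)      # the two resnet keys only
print(attentions)   # ['encoder.down.2.attn.0.k.weight']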

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 61 additions & 0 deletions
@@ -1608,3 +1608,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
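
A quick check of get_alpha_scales: musubi-tuner, like other kohya-style trainers, stores a per-module alpha, and the effective LoRA delta is scaled by alpha / rank at inference. The helper folds that scale into the saved matrices, splitting it between lora_A and lora_B by powers of two so the two factors stay close in magnitude while their product remains alpha / rank. A standalone sketch with made-up numbers, not part of the diff:

import torch

def alpha_scales(down_weight, alpha):
    # Same arithmetic as get_alpha_scales above, with alpha passed in directly.
    rank = down_weight.shape[0]
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

down = torch.zeros(16, 320)        # rank-16 LoRA down projection
print(alpha_scales(down, 8.0))     # (0.5, 1.0)  -> product 0.5   == 8 / 16
print(alpha_scales(down, 2.0))     # (0.25, 0.5) -> product 0.125 == 2 / 16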

src/diffusers/loaders/lora_pipeline.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,7 @@
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
+    _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
@@ -4794,6 +4795,8 @@ def lora_state_dict(
             )
         if any(k.startswith("diffusion_model.") for k in state_dict):
             state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
 
         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
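
With the new elif branch, Wan LoRAs exported by musubi-tuner (state-dict keys prefixed with "lora_unet_") should load through the same entry point as other non-diffusers Wan LoRAs. A hedged usage sketch: the checkpoint path is a placeholder and the base-model repo id is an assumption, not something this diff specifies.

import torch
from diffusers import WanPipeline

# Repo id assumed for illustration; any Wan pipeline checkpoint would do.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# A musubi-tuner file stores keys like "lora_unet_blocks_0_self_attn_q.lora_down.weight";
# that prefix routes the state dict through _convert_musubi_wan_lora_to_diffusers.
pipe.load_lora_weights("path/to/musubi_wan_lora.safetensors")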

src/diffusers/loaders/single_file_utils.py

Lines changed: 33 additions & 2 deletions
@@ -177,6 +177,7 @@
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
     "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
     "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
     "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -638,7 +639,9 @@ def infer_diffusers_model_type(checkpoint):
         model_type = "flux-schnell"
 
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
-        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+        if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+            model_type = "ltx-video-0.9.5"
+        elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
             model_type = "ltx-video-0.9.1"
         else:
             model_type = "ltx-video"
@@ -2403,13 +2406,41 @@ def remove_keys_(key: str, state_dict):
         "last_scale_shift_table": "scale_shift_table",
     }
 
+    VAE_095_RENAME_DICT = {
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0.upsamplers.0",
+        "up_blocks.2": "up_blocks.0",
+        "up_blocks.3": "up_blocks.1.upsamplers.0",
+        "up_blocks.4": "up_blocks.1",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.upsamplers.0",
+        "up_blocks.8": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.1",
+        "down_blocks.3": "down_blocks.1.downsamplers.0",
+        "down_blocks.4": "down_blocks.2",
+        "down_blocks.5": "down_blocks.2.downsamplers.0",
+        "down_blocks.6": "down_blocks.3",
+        "down_blocks.7": "down_blocks.3.downsamplers.0",
+        "down_blocks.8": "mid_block",
+        # common
+        "last_time_embedder": "time_embedder",
+        "last_scale_shift_table": "scale_shift_table",
+    }
+
     VAE_SPECIAL_KEYS_REMAP = {
         "per_channel_statistics.channel": remove_keys_,
         "per_channel_statistics.mean-of-means": remove_keys_,
        "per_channel_statistics.mean-of-stds": remove_keys_,
     }
 
-    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+    if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+    elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
         VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
 
     for key in list(converted_state_dict.keys()):
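
The 0.9.5 detection keys off a tensor shape rather than key presence: per the diff, the LTX-Video 0.9.5 VAE encoder's final conv takes 2048 input channels, which earlier releases do not. A minimal sketch of the same check in isolation (the helper name is made up):

def looks_like_ltx_video_095(checkpoint) -> bool:
    # Key name taken from the diff; shape[1] is the conv's input-channel count.
    weight = checkpoint.get("vae.encoder.conv_out.conv.weight")
    return weight is not None and weight.shape[1] == 2048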

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 8 additions & 2 deletions
@@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
     _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
 
     block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
-    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
-    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+    down_block_types = [
+        "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ]
+    up_block_types = [
+        "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ][::-1]
 
     config = {
         "sample_size": image_size,

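A worked example of the new block-type selection: encoder level i operates at resolution image_size // 2**i, and a level whose resolution appears in attn_resolutions now maps to the attention variant of the block. The values below are hypothetical, not from a real config:

image_size = 256
attn_resolutions = [32]                    # stand-in for vae_params["attn_resolutions"]
block_out_channels = [128, 256, 512, 512]  # stand-in for vae_params["ch"] * ch_mult

down_block_types = [
    "DownEncoderBlock2D" if image_size // 2**i not in attn_resolutions else "AttnDownEncoderBlock2D"
    for i, _ in enumerate(block_out_channels)
]
# Levels run at 256, 128, 64, 32; only the 32-pixel level gets attention:
# ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'AttnDownEncoderBlock2D']
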
src/diffusers/utils/import_utils.py

Lines changed: 6 additions & 4 deletions
@@ -101,18 +101,20 @@ def _is_package_available(pkg_name: str):
 if _onnx_available:
     candidates = (
         "onnxruntime",
+        "onnxruntime-cann",
+        "onnxruntime-directml",
+        "ort_nightly_directml",
         "onnxruntime-gpu",
         "ort_nightly_gpu",
-        "onnxruntime-directml",
+        "onnxruntime-migraphx",
         "onnxruntime-openvino",
-        "ort_nightly_directml",
+        "onnxruntime-qnn",
         "onnxruntime-rocm",
-        "onnxruntime-migraphx",
         "onnxruntime-training",
         "onnxruntime-vitisai",
     )
     _onnxruntime_version = None
-    # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+    # For the metadata, we have to look for both onnxruntime and onnxruntime-x
     for pkg in candidates:
         try:
             _onnxruntime_version = importlib_metadata.version(pkg)
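
All of these candidates are alternative distribution names for the same onnxruntime import, so the installed version has to be probed against whichever distribution is actually present. A minimal sketch of the lookup the loop performs:

import importlib.metadata as importlib_metadata

def installed_onnxruntime_version(candidates):
    for pkg in candidates:
        try:
            return importlib_metadata.version(pkg)  # first installed flavor wins
        except importlib_metadata.PackageNotFoundError:
            continue
    return None

print(installed_onnxruntime_version(("onnxruntime", "onnxruntime-gpu", "onnxruntime-qnn")))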

tests/lora/test_lora_layers_sd.py

Lines changed: 48 additions & 4 deletions
@@ -33,6 +33,7 @@
 )
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     load_image,
     nightly,
@@ -455,11 +456,54 @@ def test_vanilla_funetuning(self):
 
         images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
 
-        images = images[0, -3:, -3:, -1].flatten()
-
-        expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
+        image_slice = images[0, -3:, -3:, -1].flatten()
+
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [0.6544, 0.6127, 0.5397, 0.6845, 0.6047, 0.5469, 0.6349, 0.5906, 0.5382]
+                ),
+                ("cuda", 7): np.array(
+                    [0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]
+                ),
+                ("cuda", 8): np.array(
+                    [0.6542, 0.61253, 0.5396, 0.6843, 0.6044, 0.5468, 0.6349, 0.5905, 0.5381]
+                ),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()
 
-        max_diff = numpy_cosine_similarity_distance(expected, images)
+        max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
         assert max_diff < 1e-4
 
         pipe.unload_lora_weights()
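
From this usage, Expectations maps a (device type, version) pair to a reference slice and get_expectation() returns the entry matching the backend the test runs on; the actual selection logic lives in diffusers.utils.testing_utils and is not shown in this diff. A simplified stand-in that only illustrates the keying scheme, not the real implementation:

import numpy as np

class FakeExpectations(dict):
    def get_expectation(self, device="cuda", version=8):
        # The real helper detects device and version itself; here they are arguments.
        return self[(device, version)]

slices = FakeExpectations({
    ("cuda", 7): np.array([0.7406, 0.699, 0.5963]),
    ("cuda", 8): np.array([0.6542, 0.61253, 0.5396]),
})
print(slices.get_expectation("cuda", 8))  # picks the ("cuda", 8) entry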
