Skip to content

Commit 634aa8d

Browse files
Committed with message "update" (1 parent: 6399810; commit 634aa8d)

File tree

3 files changed: +18 −2 lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2898,12 +2898,26 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
28982898
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
28992899
converted_state_dict = {}
29002900

2901+
keys = list(checkpoint.keys())
2902+
for k in keys:
2903+
if "model.diffusion_model." in k:
2904+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2905+
29012906
TRANSFORMER_KEYS_RENAME_DICT = {
29022907
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
29032908
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
29042909
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
29052910
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
29062911
"time_projection.1": "condition_embedder.time_proj",
2912+
"cross_attn": "attn2",
2913+
"self_attn": "attn1",
2914+
".o.": ".to_out.0.",
2915+
".q.": ".to_q.",
2916+
".k.": ".to_k.",
2917+
".v.": ".to_v.",
2918+
".k_img.": ".add_k_proj.",
2919+
".v_img.": ".add_v_proj.",
2920+
".norm_k_img.": ".norm_added_k.",
29072921
"head.modulation": "scale_shift_table",
29082922
"head.head": "proj_out",
29092923
"modulation": "scale_shift_table",

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,9 @@ def __init__(
284284
self.norm_added_q = RMSNorm(dim_head, eps=eps)
285285
self.norm_added_k = RMSNorm(dim_head, eps=eps)
286286
elif qk_norm == "rms_norm_across_heads":
287-
# Wanx applies qk norm across all heads
288-
self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
287+
# Wan applies qk norm across all heads
288+
# Wan also doesn't apply a q norm
289+
self.norm_added_q = None
289290
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
290291
else:
291292
raise ValueError(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
329329
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
330330
_no_split_modules = ["WanTransformerBlock"]
331331
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
332+
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
332333

333334
@register_to_config
334335
def __init__(

0 commit comments

Comments (0)