Commit e068d10

fix wan umt5 state dict converter (#170)
1 parent: 63188ac

File tree

1 file changed: +13 -13 lines


diffsynth_engine/models/wan/wan_text_encoder.py

Lines changed: 13 additions & 13 deletions
@@ -198,22 +198,22 @@ def __init__(self, num_encoder_layers: int = 24):
 
     def _from_diffusers(self, state_dict):
         rename_dict = {
-            "enc.output_norm.weight": "norm.weight",
-            "token_embd.weight": "token_embedding.weight",
+            "shared.weight": "token_embedding.weight",
+            "encoder.final_layer_norm.weight": "norm.weight",
         }
         for i in range(self.num_encoder_layers):
             rename_dict.update(
                 {
-                    f"enc.blk.{i}.attn_q.weight": f"blocks.{i}.attn.q.weight",
-                    f"enc.blk.{i}.attn_k.weight": f"blocks.{i}.attn.k.weight",
-                    f"enc.blk.{i}.attn_v.weight": f"blocks.{i}.attn.v.weight",
-                    f"enc.blk.{i}.attn_o.weight": f"blocks.{i}.attn.o.weight",
-                    f"enc.blk.{i}.ffn_up.weight": f"blocks.{i}.ffn.fc1.weight",
-                    f"enc.blk.{i}.ffn_down.weight": f"blocks.{i}.ffn.fc2.weight",
-                    f"enc.blk.{i}.ffn_gate.weight": f"blocks.{i}.ffn.gate.0.weight",
-                    f"enc.blk.{i}.attn_norm.weight": f"blocks.{i}.norm1.weight",
-                    f"enc.blk.{i}.ffn_norm.weight": f"blocks.{i}.norm2.weight",
-                    f"enc.blk.{i}.attn_rel_b.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
+                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                 }
             )
 
@@ -224,7 +224,7 @@ def _from_diffusers(self, state_dict):
         return new_state_dict
 
     def convert(self, state_dict):
-        if "enc.output_norm.weight" in state_dict:
+        if "encoder.final_layer_norm.weight" in state_dict:
             logger.info("use diffusers format state dict")
             return self._from_diffusers(state_dict)
         return state_dict
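In short, the old rename map used GGUF-style tensor names (enc.blk.{i}.attn_q.weight, token_embd.weight, enc.output_norm.weight), so checkpoints in the actual Hugging Face / diffusers UMT5 layout were never matched; the fix swaps in the correct encoder.block.{i}.layer.{0,1}.* keys and updates the sentinel key that convert uses for detection. Below is a minimal sketch of how such a rename map gets applied; the real application loop in _from_diffusers sits between the two hunks and is not part of this diff, so apply_rename and the toy tensors are illustrative assumptions, not repository code.

# Sketch only: apply_rename is a hypothetical stand-in for the omitted
# application loop in _from_diffusers; strings stand in for tensors.
def apply_rename(state_dict, rename_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in rename_dict:
            new_state_dict[rename_dict[key]] = value
    return new_state_dict

# Diffusers-format UMT5 keys (left) now map onto the engine's names (right).
renames = {
    "shared.weight": "token_embedding.weight",
    "encoder.final_layer_norm.weight": "norm.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight": "blocks.0.attn.q.weight",
}
src = {
    "shared.weight": "emb",
    "encoder.final_layer_norm.weight": "ln",
    "encoder.block.0.layer.0.SelfAttention.q.weight": "q",
}
print(apply_rename(src, renames))
# {'token_embedding.weight': 'emb', 'norm.weight': 'ln', 'blocks.0.attn.q.weight': 'q'}

Using encoder.final_layer_norm.weight as the detection key is a sensible sentinel: it occurs exactly once in the Hugging Face UMT5 encoder layout, so its presence identifies a diffusers-format state dict without scanning every key.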
