Skip to content

Commit 27dcb5f

Browse files
authored
fix flux param (#89)
* fix flux param * add comment
1 parent aa6eefc commit 27dcb5f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc
4343
suffix = ".weight" if name.endswith(".weight") else ".bias"
4444
prefix = name[: -len(suffix)]
4545
if prefix in global_rename_dict:
46+
# Fix load diffusers format weights [issue](https://github.com/modelscope/DiffSynth-Engine/issues/90).
47+
if prefix.startswith("norm_out.linear"):
48+
param = torch.concat([param[dim:], param[:dim]], dim=0)
4649
state_dict_[global_rename_dict[prefix] + suffix] = param
4750
elif prefix.startswith("transformer_blocks."):
4851
names = prefix.split(".")
@@ -82,7 +85,7 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc
8285
dim=0,
8386
)
8487
state_dict_[name.replace(".proj_in_besides_attn.", ".attn.to_qkv.")] = param
85-
state_dict_[name.replace(".proj_in_besides_attn.", ".mlp.0.")] = state_dict_[name_]
88+
state_dict_[name.replace(".proj_in_besides_attn.", ".mlp.0.")] = state_dict_[name]
8689
state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_q."))
8790
state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_k."))
8891
state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_v."))

0 commit comments

Comments
 (0)