Skip to content

Commit 8403860

Browse files
hlkyAki-07
authored andcommitted
Z-Image-Turbo from_single_file fix (huggingface#12888)
1 parent cc9f7d3 commit 8403860

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@
120120
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
121121
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
122122
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
123-
"z-image-turbo": "cap_embedder.0.weight",
123+
"z-image-turbo": [
124+
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
125+
"layers.0.adaLN_modulation.0.weight",
126+
],
124127
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
125128
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
126129
"sana": [
@@ -727,10 +730,7 @@ def infer_diffusers_model_type(checkpoint):
727730
):
728731
model_type = "instruct-pix2pix"
729732

730-
elif (
731-
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
732-
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
733-
):
733+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
734734
model_type = "z-image-turbo"
735735

736736
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -3852,6 +3852,7 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
38523852
".attention.k_norm.weight": ".attention.norm_k.weight",
38533853
".attention.q_norm.weight": ".attention.norm_q.weight",
38543854
".attention.out.weight": ".attention.to_out.0.weight",
3855+
"model.diffusion_model.": "",
38553856
}
38563857

38573858
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3886,6 +3887,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
38863887

38873888
update_state_dict(converted_state_dict, key, new_key)
38883889

3890+
if "norm_final.weight" in converted_state_dict.keys():
3891+
_ = converted_state_dict.pop("norm_final.weight")
3892+
38893893
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
38903894
# special_keys_remap
38913895
for key in list(converted_state_dict.keys()):

0 commit comments

Comments
 (0)