@@ -120,7 +120,10 @@
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
-    "z-image-turbo": "cap_embedder.0.weight",
+    "z-image-turbo": [
+        "model.diffusion_model.layers.0.adaLN_modulation.0.weight",
+        "layers.0.adaLN_modulation.0.weight",
+    ],
     "z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
     "z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
     "sana": [
@@ -727,10 +730,7 @@ def infer_diffusers_model_type(checkpoint): |
     ):
         model_type = "instruct-pix2pix"
 
-    elif (
-        CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
-        and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
-    ):
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
         model_type = "z-image-turbo"
 
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
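Taken together, the first two hunks replace a single-key, shape-gated check with a presence test over a list of candidate keys, so Z-Image checkpoints are recognized whether or not their keys carry the `model.diffusion_model.` prefix. A minimal sketch of that detection pattern, with a one-entry registry standing in for the real `CHECKPOINT_KEY_NAMES` table:

```python
# Sketch of key-presence model-type detection. The registry below is a
# one-entry stand-in; the real CHECKPOINT_KEY_NAMES table covers many models.
CHECKPOINT_KEY_NAMES = {
    "z-image-turbo": [
        "model.diffusion_model.layers.0.adaLN_modulation.0.weight",  # prefixed layout
        "layers.0.adaLN_modulation.0.weight",  # prefix already stripped
    ],
}

def infer_model_type(checkpoint: dict) -> str | None:
    # Matching on any candidate key covers both key layouts.
    if any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
        return "z-image-turbo"
    return None

# Either layout is detected:
assert infer_model_type({"layers.0.adaLN_modulation.0.weight": ...}) == "z-image-turbo"
assert infer_model_type({"model.diffusion_model.layers.0.adaLN_modulation.0.weight": ...}) == "z-image-turbo"
```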
@@ -3852,6 +3852,7 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): |
         ".attention.k_norm.weight": ".attention.norm_k.weight",
         ".attention.q_norm.weight": ".attention.norm_q.weight",
         ".attention.out.weight": ".attention.to_out.0.weight",
+        "model.diffusion_model.": "",
     }
 
     def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
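The new `"model.diffusion_model.": ""` entry is worth a note: because the replacement string is empty, running it through the same substring-replacement pass as the other renames strips the prefix from every key. A minimal sketch of that remapping idea, assuming a plain `str.replace` loop like the one the converter appears to use (`RENAME_MAP` and `remap_key` are illustrative names, not the converter's own):

```python
# Sketch of substring-based key remapping; an empty replacement deletes a prefix.
RENAME_MAP = {
    ".attention.out.weight": ".attention.to_out.0.weight",
    "model.diffusion_model.": "",  # empty string -> prefix is stripped
}

def remap_key(key: str) -> str:
    for old, new in RENAME_MAP.items():
        key = key.replace(old, new)
    return key

assert (
    remap_key("model.diffusion_model.layers.0.attention.out.weight")
    == "layers.0.attention.to_out.0.weight"
)
```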
@@ -3886,6 +3887,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) |
 
         update_state_dict(converted_state_dict, key, new_key)
 
+    if "norm_final.weight" in converted_state_dict.keys():
+        _ = converted_state_dict.pop("norm_final.weight")
+
     # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
     # special_keys_remap
     for key in list(converted_state_dict.keys()):
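Finally, after all renames have run, the new guard pops `norm_final.weight` from the converted state dict; a plausible reason (not stated in the diff itself) is that the diffusers Z-Image transformer defines no parameter of that name, so a leftover key would break a strict `load_state_dict`. A small self-contained sketch of the pattern, with an assumed one-line body for `update_state_dict`:

```python
def update_state_dict(state_dict: dict, old_key: str, new_key: str) -> None:
    # Assumed behavior: move a tensor to its renamed key in place.
    state_dict[new_key] = state_dict.pop(old_key)

converted_state_dict = {
    "attention.out.weight": 0,  # gets renamed
    "norm_final.weight": 1,     # assumed to have no counterpart in the target module
}
update_state_dict(converted_state_dict, "attention.out.weight", "attention.to_out.0.weight")

# Drop keys the target model does not define so strict loading can succeed.
if "norm_final.weight" in converted_state_dict:
    _ = converted_state_dict.pop("norm_final.weight")

assert converted_state_dict == {"attention.to_out.0.weight": 0}
```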