import argparse

import torch

from diffusers import HunyuanDiT2DControlNetModel

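# This script converts a Tencent HunyuanDiT ControlNet checkpoint (saved with torch.save) into the
# diffusers HunyuanDiT2DControlNetModel layout. Example invocation, where the script path, the
# checkpoint paths, and the load key are placeholders rather than values shipped with this script:
#   python <path/to/this_script.py> \
#       --pt_checkpoint_path ./controlnet.pt --load_key none \
#       --output_checkpoint_path ./controlnet-diffusers \
#       --use_style_cond_and_image_meta_size false
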
def main(args):
    state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")

    if args.load_key != "none":
        try:
            state_dict = state_dict[args.load_key]
        except KeyError:
            raise KeyError(
                f"{args.load_key} not found in the checkpoint. "
                f"Please load from the following keys: {list(state_dict.keys())}"
            )
    device = "cuda"

    model_config = HunyuanDiT2DControlNetModel.load_config(
        "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
    )
    model_config[
        "use_style_cond_and_image_meta_size"
    ] = args.use_style_cond_and_image_meta_size  ### version <= v1.1: True; version >= v1.2: False
    print(model_config)

    for key in state_dict:
        print("local:", key)

    model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)

    for key in model.state_dict():
        print("diffusers:", key)

    num_layers = 19
    for i in range(num_layers):
        # attn1
        # Wqkv -> to_q, to_k, to_v (the fused QKV projection is split into thirds along dim 0)
        q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
        state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
        state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
        state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
        state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
        state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")

        # attn2
        # kv_proj -> to_k, to_v
        k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
        state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")

        # q_proj -> to_q
        state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
        state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")

        # switch norm 2 and norm 3
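        # (diffusers' HunyuanDiTBlock applies norm2 before cross-attention and norm3 before the
        # feed-forward, the reverse of the original naming, hence the swap below)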
        norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
        norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
        state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
        state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
        state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
        state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias

        # norm1 -> norm1.norm
        # default_modulation.1 -> norm1.linear
        state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
        state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
        state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
        state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
        state_dict.pop(f"blocks.{i}.norm1.weight")
        state_dict.pop(f"blocks.{i}.norm1.bias")
        state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
        state_dict.pop(f"blocks.{i}.default_modulation.1.bias")

        # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
        state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
        state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
        state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
        state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
        state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
        state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc2.bias")

        # after_proj_list -> controlnet_blocks
        state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
        state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
        state_dict.pop(f"after_proj_list.{i}.weight")
        state_dict.pop(f"after_proj_list.{i}.bias")

    # before_proj -> input_block
    state_dict["input_block.weight"] = state_dict["before_proj.weight"]
    state_dict["input_block.bias"] = state_dict["before_proj.bias"]
    state_dict.pop("before_proj.weight")
    state_dict.pop("before_proj.bias")

    # pooler -> time_extra_emb
    state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
    state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
    state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
    state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
    state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
    state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
    state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
    state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
    state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
    state_dict.pop("pooler.k_proj.weight")
    state_dict.pop("pooler.k_proj.bias")
    state_dict.pop("pooler.q_proj.weight")
    state_dict.pop("pooler.q_proj.bias")
    state_dict.pop("pooler.v_proj.weight")
    state_dict.pop("pooler.v_proj.bias")
    state_dict.pop("pooler.c_proj.weight")
    state_dict.pop("pooler.c_proj.bias")
    state_dict.pop("pooler.positional_embedding")

    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
    state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]

    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.2.bias")
    state_dict.pop("t_embedder.mlp.2.weight")

    # x_embedder -> pos_embed (`PatchEmbed`)
    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    # mlp_t5 -> text_embedder
    state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
    state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
    state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
    state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
    state_dict.pop("mlp_t5.0.bias")
    state_dict.pop("mlp_t5.0.weight")
    state_dict.pop("mlp_t5.2.bias")
    state_dict.pop("mlp_t5.2.weight")

    # extra_embedder -> time_extra_emb.extra_embedder
    state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
    state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
    state_dict.pop("extra_embedder.0.bias")
    state_dict.pop("extra_embedder.0.weight")
    state_dict.pop("extra_embedder.2.bias")
    state_dict.pop("extra_embedder.2.weight")

    # style_embedder
    if model_config["use_style_cond_and_image_meta_size"]:
        print(state_dict["style_embedder.weight"])
        print(state_dict["style_embedder.weight"].shape)
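        # keep only the first style embedding row, matching the single-row
        # time_extra_emb.style_embedder weight on the diffusers side (hence the [0:1] slice)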
        state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
        state_dict.pop("style_embedder.weight")

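    # strict loading (the default) raises if any checkpoint keys were left unconverted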
    model.load_state_dict(state_dict)

    if args.save:
        model.save_pretrained(args.output_checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        # parse booleans explicitly: a bare `bool` type would treat any non-empty string (even "False") as True
        "--save", default=True, type=lambda x: str(x).lower() in ("1", "true", "yes"), required=False,
        help="Whether to save the converted pipeline or not."
    )
    parser.add_argument(
        "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
    )
    parser.add_argument(
        "--output_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the output converted diffusers pipeline.",
    )
    parser.add_argument(
        "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
    )
    parser.add_argument(
        "--use_style_cond_and_image_meta_size",
        type=lambda x: str(x).lower() in ("1", "true", "yes"),  # bare `bool` would treat any non-empty string as True
        default=False,
        help="version <= v1.1: True; version >= v1.2: False",
    )

    args = parser.parse_args()
    main(args)