@@ -106,6 +106,7 @@
     ],
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -159,6 +160,7 @@
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "autoencoder-dc-f128c512"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
     else:
         model_type = "v1"
 
@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
 def get_attn2_layers(state_dict):
     attn2_layers = []
     for key in state_dict.keys():
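
For reference, the new swap_proj_gate helper just swaps the two halves of a fused feed-forward input projection along dim 0 so that the second ("gate") half comes first. A minimal sanity check, with the helper condensed from the hunk above (the toy tensor is purely illustrative and not part of this diff):

import torch

def swap_proj_gate(weight):  # condensed copy of the helper added above
    proj, gate = weight.chunk(2, dim=0)
    return torch.cat([gate, proj], dim=0)

w = torch.arange(8.0).reshape(4, 2)  # rows 0-1 are the "proj" half, rows 2-3 the "gate" half
print(swap_proj_gate(w))  # -> [[4., 5.], [6., 7.], [0., 1.], [2., 3.]]: the "gate" rows now lead
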
@@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict):
             handler_fn_inplace(key, converted_state_dict)
 
     return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_y.weight"
+        )
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_y.weight"
+        )
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.weight"
+            )
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+    return new_state_dict
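
Putting the pieces together, a minimal sketch of how the new code paths could be exercised (hypothetical driver code, not part of this diff; it assumes the functions above are importable from this module and that a full Mochi checkpoint is available locally as mochi_1_preview.safetensors):

from safetensors.torch import load_file

# Hypothetical local path to an original genmo/mochi-1-preview transformer checkpoint.
checkpoint = load_file("mochi_1_preview.safetensors")

# The qkv_x key registered in CHECKPOINT_KEY_NAMES above lets the checkpoint be recognized,
# and DIFFUSERS_DEFAULT_PIPELINE_PATHS maps the result to "genmo/mochi-1-preview".
model_type = infer_diffusers_model_type(checkpoint)
assert model_type == "mochi-1-preview"

# Remap the original keys (48 transformer blocks) to the diffusers naming scheme.
converted_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(checkpoint)
print(len(converted_state_dict), "converted tensors")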