|  | 
| 106 | 106 |     ], | 
| 107 | 107 |     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", | 
| 108 | 108 |     "autoencoder-dc-sana": "encoder.project_in.conv.bias", | 
|  | 109 | +    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], | 
| 109 | 110 | } | 
| 110 | 111 | 
 | 
| 111 | 112 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { | 
|  | 
| 157 | 158 |     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, | 
| 158 | 159 |     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, | 
| 159 | 160 |     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, | 
|  | 161 | +    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, | 
| 160 | 162 | } | 
| 161 | 163 | 
 | 
| 162 | 164 | # Use to configure model sample size when original config is provided | 
| @@ -610,6 +612,9 @@ def infer_diffusers_model_type(checkpoint): | 
| 610 | 612 |         else: | 
| 611 | 613 |             model_type = "autoencoder-dc-f128c512" | 
| 612 | 614 | 
 | 
|  | 615 | +    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): | 
|  | 616 | +        model_type = "mochi-1-preview" | 
|  | 617 | + | 
| 613 | 618 |     else: | 
| 614 | 619 |         model_type = "v1" | 
| 615 | 620 | 
 | 
| @@ -1750,6 +1755,12 @@ def swap_scale_shift(weight, dim): | 
| 1750 | 1755 |     return new_weight | 
| 1751 | 1756 | 
 | 
| 1752 | 1757 | 
 | 
def swap_proj_gate(weight):
    """Return ``weight`` with its two halves along dim 0 exchanged.

    The tensor is split into two chunks along the first dimension and the
    chunks are concatenated back in the opposite order, i.e. the second chunk
    comes first in the result.

    NOTE(review): judging by the function name, the two halves are presumably
    the "proj" and "gate" parts of a fused gated-MLP weight, stored in the
    opposite order by the source checkpoint - confirm against the callers.

    Args:
        weight: tensor to reorder; it must split into exactly two chunks
            along dim 0.

    Returns:
        A new tensor with the same shape as ``weight``.
    """
    first_half, second_half = torch.chunk(weight, 2, dim=0)
    return torch.cat((second_half, first_half), dim=0)
|  | 1762 | + | 
|  | 1763 | + | 
| 1753 | 1764 | def get_attn2_layers(state_dict): | 
| 1754 | 1765 |     attn2_layers = [] | 
| 1755 | 1766 |     for key in state_dict.keys(): | 
| @@ -2406,3 +2417,101 @@ def remap_proj_conv_(key: str, state_dict): | 
| 2406 | 2417 |             handler_fn_inplace(key, converted_state_dict) | 
| 2407 | 2418 | 
 | 
| 2408 | 2419 |     return converted_state_dict | 
|  | 2420 | + | 
|  | 2421 | + | 
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap an original-layout Mochi transformer state dict to diffusers key names.

    Entries are moved (popped) out of ``checkpoint`` into a freshly created
    dict, so ``checkpoint`` is modified in place. Keys containing the
    ``model.diffusion_model.`` prefix (added by Comfy checkpoints) are first
    renamed inside ``checkpoint`` itself.

    Besides plain renames, the conversion:

    * splits each fused ``attn.qkv_x`` / ``attn.qkv_y`` weight into three equal
      chunks along dim 0 and stores them as separate q / k / v projections;
    * passes the ``mlp_x.w1`` / ``mlp_y.w1`` weights through ``swap_proj_gate``;
    * passes the ``final_layer.mod`` weight and bias through
      ``swap_scale_shift`` (defined elsewhere in this module).

    The last transformer block is handled differently from the others: its
    ``mod_y`` parameters are stored under ``norm1_context.linear_1`` instead of
    ``norm1_context.linear``, and its ``attn.proj_y`` and ``mlp_y`` entries are
    neither converted nor removed from ``checkpoint``.

    The number of transformer blocks is inferred from the ``blocks.<i>.`` keys
    present in the checkpoint, so variants with a different depth convert
    without code changes. If no such key is found, the historical default of
    48 is used, which keeps the ``KeyError`` below for malformed checkpoints.

    Args:
        checkpoint: mapping from original parameter names to tensors.
        **kwargs: accepted for call-signature compatibility; not used.

    Returns:
        dict: a new state dict keyed by diffusers parameter names.

    Raises:
        KeyError: if an expected key is missing from ``checkpoint``
            (assuming ``checkpoint`` is a regular ``dict``).
    """
    new_state_dict = {}

    # Comfy checkpoints add this prefix
    comfy_prefix = "model.diffusion_model."
    for key in list(checkpoint.keys()):  # snapshot: the dict is mutated below
        if comfy_prefix in key:
            checkpoint[key.replace(comfy_prefix, "")] = checkpoint.pop(key)

    # Convert patch_embed and time_embed: one-to-one renames, each module
    # contributing a weight followed by a bias.
    embedder_renames = {
        "x_embedder.proj": "patch_embed.proj",
        "t_embedder.mlp.0": "time_embed.timestep_embedder.linear_1",
        "t_embedder.mlp.2": "time_embed.timestep_embedder.linear_2",
        "t5_y_embedder.to_kv": "time_embed.pooler.to_kv",
        "t5_y_embedder.to_q": "time_embed.pooler.to_q",
        "t5_y_embedder.to_out": "time_embed.pooler.to_out",
        "t5_yproj": "time_embed.caption_proj",
    }
    for old_name, new_name in embedder_renames.items():
        for suffix in (".weight", ".bias"):
            new_state_dict[new_name + suffix] = checkpoint.pop(old_name + suffix)

    # Convert transformer blocks
    # Infer the depth from the highest ``blocks.<i>.`` index; using max + 1
    # (rather than a count) means a gap in the indices still raises KeyError.
    block_indices = [
        int(parts[1])
        for parts in (key.split(".") for key in checkpoint)
        if len(parts) > 2 and parts[0] == "blocks" and parts[1].isdecimal()
    ]
    num_layers = max(block_indices) + 1 if block_indices else 48

    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"blocks.{i}."
        is_last_block = i == num_layers - 1

        # norm1
        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
        # The last block stores its context modulation under a different name.
        context_norm = "norm1_context.linear_1" if is_last_block else "norm1_context.linear"
        new_state_dict[block_prefix + context_norm + ".weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
        new_state_dict[block_prefix + context_norm + ".bias"] = checkpoint.pop(old_prefix + "mod_y.bias")

        # Visual attention: the fused qkv weight is split into three chunks.
        q, k, v = checkpoint.pop(old_prefix + "attn.qkv_x.weight").chunk(3, dim=0)
        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")

        # Context attention
        q, k, v = checkpoint.pop(old_prefix + "attn.qkv_y.weight").chunk(3, dim=0)
        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
            old_prefix + "attn.q_norm_y.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
            old_prefix + "attn.k_norm_y.weight"
        )
        if not is_last_block:
            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
                old_prefix + "attn.proj_y.weight"
            )
            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")

        # MLP
        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
        )
        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
        if not is_last_block:
            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
            )
            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")

    # Output layers
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

    return new_state_dict
0 commit comments