Skip to content

Commit 76b2ea4

Browse files
committed
update
1 parent 128b96f commit 76b2ea4

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
convert_ldm_vae_checkpoint,
3333
convert_ltx_transformer_checkpoint_to_diffusers,
3434
convert_ltx_vae_checkpoint_to_diffusers,
35+
convert_mochi_transformer_checkpoint_to_diffusers,
3536
convert_sd3_transformer_checkpoint_to_diffusers,
3637
convert_stable_cascade_unet_single_file_to_diffusers,
3738
create_controlnet_diffusers_config_from_ldm,
@@ -96,6 +97,10 @@
9697
"default_subfolder": "vae",
9798
},
9899
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
100+
"MochiTransformer3DModel": {
101+
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
102+
"default_subfolder": "transformer",
103+
},
99104
}
100105

101106

src/diffusers/loaders/single_file_utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
],
107107
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
108108
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
109+
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
109110
}
110111

111112
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -157,6 +158,7 @@
157158
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
158159
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
159160
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
161+
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
160162
}
161163

162164
# Use to configure model sample size when original config is provided
@@ -610,6 +612,9 @@ def infer_diffusers_model_type(checkpoint):
610612
else:
611613
model_type = "autoencoder-dc-f128c512"
612614

615+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
616+
model_type = "mochi-1-preview"
617+
613618
else:
614619
model_type = "v1"
615620

@@ -1750,6 +1755,12 @@ def swap_scale_shift(weight, dim):
17501755
return new_weight
17511756

17521757

1758+
def swap_proj_gate(weight):
1759+
proj, gate = weight.chunk(2, dim=0)
1760+
new_weight = torch.cat([gate, proj], dim=0)
1761+
return new_weight
1762+
1763+
17531764
def get_attn2_layers(state_dict):
17541765
attn2_layers = []
17551766
for key in state_dict.keys():
@@ -2406,3 +2417,101 @@ def remap_proj_conv_(key: str, state_dict):
24062417
handler_fn_inplace(key, converted_state_dict)
24072418

24082419
return converted_state_dict
2420+
2421+
2422+
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2423+
new_state_dict = {}
2424+
2425+
# Comfy checkpoints add this prefix
2426+
keys = list(checkpoint.keys())
2427+
for k in keys:
2428+
if "model.diffusion_model." in k:
2429+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2430+
2431+
# Convert patch_embed
2432+
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
2433+
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
2434+
2435+
# Convert time_embed
2436+
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
2437+
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
2438+
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
2439+
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
2440+
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
2441+
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
2442+
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
2443+
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
2444+
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
2445+
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
2446+
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
2447+
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
2448+
2449+
# Convert transformer blocks
2450+
num_layers = 48
2451+
for i in range(num_layers):
2452+
block_prefix = f"transformer_blocks.{i}."
2453+
old_prefix = f"blocks.{i}."
2454+
2455+
# norm1
2456+
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
2457+
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
2458+
if i < num_layers - 1:
2459+
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
2460+
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2461+
else:
2462+
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
2463+
old_prefix + "mod_y.weight"
2464+
)
2465+
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2466+
2467+
# Visual attention
2468+
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
2469+
q, k, v = qkv_weight.chunk(3, dim=0)
2470+
2471+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
2472+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
2473+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
2474+
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
2475+
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
2476+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
2477+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
2478+
2479+
# Context attention
2480+
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
2481+
q, k, v = qkv_weight.chunk(3, dim=0)
2482+
2483+
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
2484+
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
2485+
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
2486+
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
2487+
old_prefix + "attn.q_norm_y.weight"
2488+
)
2489+
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
2490+
old_prefix + "attn.k_norm_y.weight"
2491+
)
2492+
if i < num_layers - 1:
2493+
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
2494+
old_prefix + "attn.proj_y.weight"
2495+
)
2496+
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
2497+
2498+
# MLP
2499+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
2500+
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
2501+
)
2502+
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
2503+
if i < num_layers - 1:
2504+
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
2505+
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
2506+
)
2507+
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
2508+
2509+
# Output layers
2510+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
2511+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
2512+
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
2513+
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
2514+
2515+
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
2516+
2517+
return new_state_dict

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23+
from ...loaders.single_file_model import FromOriginalModelMixin
2324
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
2425
from ...utils.torch_utils import maybe_allow_in_graph
2526
from ..attention import FeedForward
@@ -304,7 +305,7 @@ def forward(
304305

305306

306307
@maybe_allow_in_graph
307-
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
308+
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
308309
r"""
309310
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
310311

0 commit comments

Comments
 (0)