Commit a7372bd

mochi transformer

1 parent 1d1e1a2 · commit a7372bd

15 files changed: +1969 -9 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -266,6 +266,8 @@
       title: LatteTransformer3DModel
     - local: api/models/lumina_nextdit2d
       title: LuminaNextDiT2DModel
+    - local: api/models/mochi_transformer3d
+      title: MochiTransformer3DModel
     - local: api/models/pixart_transformer2d
       title: PixArtTransformer2DModel
     - local: api/models/prior_transformer

docs/source/en/api/models/mochi_transformer3d.md

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# MochiTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import MochiTransformer3DModel

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16
).to("cuda")
```

## MochiTransformer3DModel

[[autodoc]] MochiTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
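
As a quick sanity check after loading, the registered configuration and parameter count can be inspected. A minimal sketch building on the snippet above (it assumes the Hub repository exposes the `transformer` subfolder shown there):

```python
import torch

from diffusers import MochiTransformer3DModel

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16
)
print(transformer.config)            # registered init arguments (patch size, layer count, dims, ...)
print(transformer.num_parameters())  # total parameter count
```
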
Lines changed: 187 additions & 0 deletions

@@ -0,0 +1,187 @@
import argparse
import os
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file

# from transformers import T5EncoderModel, T5Tokenizer

from diffusers import MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 256

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the Diffusers implementation splits the linear projection into (scale, shift), while Mochi splits it into (shift, scale).
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = load_file(ckpt_path, device="cpu")
    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")

    # Convert time_embed
    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
    new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
    new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
    new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
    new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
    new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
    new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
    new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
    new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")

    # Convert transformer blocks
    num_layers = 48
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"blocks.{i}."

        # norm1
        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
                old_prefix + "mod_y.bias"
            )
        else:
            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
                old_prefix + "mod_y.bias"
            )

        # Visual attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attn.proj_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")

        # Context attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_y.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_y.weight"
        )
        if i < num_layers - 1:
            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.weight"
            )
            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.bias"
            )

        # MLP
        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w1.weight"
            )
            new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w2.weight"
            )

    # Output layers
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.weight"), dim=0)
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
    new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

    new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")

    # Any keys left over were not mapped into the Diffusers layout.
    print("Remaining Keys:", original_state_dict.keys())

    return new_state_dict


# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
#     original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    # vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # Cast to the requested dtype; the original checkpoint dtype is not preserved.
            transformer = transformer.to(dtype=dtype)

    # text_encoder_id = "google/t5-v1_1-xxl"
    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # # Apparently, the conversion does not work anymore without this :shrug:
    # for param in text_encoder.parameters():
    #     param.data = param.data.contiguous()

    transformer.save_pretrained(os.path.join(args.output_path, "transformer"))


if __name__ == "__main__":
    main(args)
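
For reference, a hedged sketch of driving the conversion from Python rather than the CLI, with the functions above in scope; the checkpoint and output paths are placeholders:

```python
# Assumes convert_mochi_transformer_checkpoint_to_diffusers and MochiTransformer3DModel
# are in scope (e.g. the script above was imported or pasted into a session).
state_dict = convert_mochi_transformer_checkpoint_to_diffusers("/path/to/mochi/dit.safetensors")

transformer = MochiTransformer3DModel()
transformer.load_state_dict(state_dict, strict=True)
transformer.save_pretrained("./mochi-diffusers/transformer")
```
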

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -100,6 +100,7 @@
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LuminaNextDiT2DModel",
+        "MochiTransformer3DModel",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -579,6 +580,7 @@
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,7 @@
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -106,6 +107,7 @@
         HunyuanDiT2DModel,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         SD3Transformer2DModel,
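
Together with the top-level registration in `src/diffusers/__init__.py` above, this exposes the new class from both namespaces; a trivial check:

```python
# Both import paths added by this commit should resolve to the same class object.
from diffusers import MochiTransformer3DModel as top_level
from diffusers.models import MochiTransformer3DModel as from_models

assert top_level is from_models
```
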

src/diffusers/models/activations.py

Lines changed: 5 additions & 1 deletion

@@ -134,14 +134,18 @@ class SwiGLU(nn.Module):
         bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """
 
-    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
         super().__init__()
+        self.flip_gate = flip_gate
+
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
         self.activation = nn.SiLU()
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
         hidden_states, gate = hidden_states.chunk(2, dim=-1)
+        if self.flip_gate:
+            hidden_states, gate = gate, hidden_states
         return hidden_states * self.activation(gate)
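
To illustrate the new flag, a small sketch using `SwiGLU` as it stands in this commit (the `flip_gate` argument may not exist in other versions): with `flip_gate=True` the two halves of the projection swap roles, so identical weights produce a different output than the default gating order.

```python
import torch

from diffusers.models.activations import SwiGLU

torch.manual_seed(0)
x = torch.randn(1, 4, 16)

default = SwiGLU(dim_in=16, dim_out=32)
flipped = SwiGLU(dim_in=16, dim_out=32, flip_gate=True)
flipped.load_state_dict(default.state_dict())  # same weights; only the gating order differs

print(default(x).shape)                        # torch.Size([1, 4, 32])
print(torch.allclose(default(x), flipped(x)))  # False: value and gate halves are swapped
```
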

src/diffusers/models/attention.py

Lines changed: 2 additions & 1 deletion

@@ -1206,6 +1206,7 @@ def __init__(
         final_dropout: bool = False,
         inner_dim=None,
         bias: bool = True,
+        flip_gate: bool = False,
     ):
         super().__init__()
         if inner_dim is None:
@@ -1221,7 +1222,7 @@
         elif activation_fn == "geglu-approximate":
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         elif activation_fn == "swiglu":
-            act_fn = SwiGLU(dim, inner_dim, bias=bias)
+            act_fn = SwiGLU(dim, inner_dim, bias=bias, flip_gate=flip_gate)
 
         self.net = nn.ModuleList([])
         # project in
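
And the corresponding plumbing in `FeedForward`: per this hunk, `flip_gate` is only consumed when `activation_fn="swiglu"`. A usage sketch under that assumption (the argument may have changed after this commit):

```python
import torch

from diffusers.models.attention import FeedForward

# flip_gate is forwarded to SwiGLU only for the "swiglu" activation, per the diff above.
ff = FeedForward(dim=128, activation_fn="swiglu", flip_gate=True)
out = ff(torch.randn(1, 7, 128))
print(out.shape)  # torch.Size([1, 7, 128])
```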
