
Commit c12ce7d

Merge branch 'mochi-t2v' into mochi-t2v-pipeline
2 parents 275041d + 46f95d5 commit c12ce7d

File tree

10 files changed: +924, -267 lines

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file

# from transformers import T5EncoderModel, T5Tokenizer
from diffusers import MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 256

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the Diffusers implementation splits the linear projection into (scale, shift), while Mochi splits it into (shift, scale)
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = load_file(ckpt_path, device="cpu")
    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")

    # Convert time_embed
    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
    new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
    new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
    new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
    new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
    new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
    new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
    new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
    new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")

    # Convert transformer blocks
    num_layers = 48
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"blocks.{i}."

        # norm1
        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
                old_prefix + "mod_y.bias"
            )
        else:
            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
                old_prefix + "mod_y.bias"
            )

        # Visual attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attn.proj_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")

        # Context attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_y.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_y.weight"
        )
        if i < num_layers - 1:
            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.weight"
            )
            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.bias"
            )

        # MLP
        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w1.weight"
            )
            new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w2.weight"
            )

    # Output layers
    new_state_dict["norm_out.linear.weight"] = original_state_dict.pop("final_layer.mod.weight")
    new_state_dict["norm_out.linear.bias"] = original_state_dict.pop("final_layer.mod.bias")
    new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

    new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")

    print("Remaining Keys:", original_state_dict.keys())

    return new_state_dict


# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
#     original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    # vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # Cast only when a dtype was explicitly requested; otherwise the original checkpoint dtype is preserved
            transformer = transformer.to(dtype=dtype)

    # text_encoder_id = "google/t5-v1_1-xxl"
    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # # Apparently, the conversion does not work anymore without this :shrug:
    # for param in text_encoder.parameters():
    #     param.data = param.data.contiguous()

    transformer.save_pretrained(args.output_path, subfolder="transformer")


if __name__ == "__main__":
    main(args)
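After conversion, the saved weights can in principle be reloaded through the standard diffusers model API. A minimal sketch, assuming the files ended up in ./mochi-transformer, which is a placeholder for whatever directory save_pretrained actually wrote to:

import torch
from diffusers import MochiTransformer3DModel

# "./mochi-transformer" is a placeholder; point it at the directory produced by the script above.
transformer = MochiTransformer3DModel.from_pretrained("./mochi-transformer", torch_dtype=torch.bfloat16)
transformer.eval()
print(sum(p.numel() for p in transformer.parameters()), "parameters")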

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LuminaNextDiT2DModel",
+        "MochiTransformer3DModel",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -579,6 +580,7 @@
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -106,6 +107,7 @@
         HunyuanDiT2DModel,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         SD3Transformer2DModel,

src/diffusers/models/attention_processor.py

Lines changed: 91 additions & 1 deletion
@@ -120,6 +120,7 @@ def __init__(
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ def __init__(
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
 
@@ -241,7 +243,7 @@ def __init__(
             self.to_out.append(nn.Dropout(dropout))
 
         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
 
         if qk_norm is not None and added_kv_proj_dim is not None:
             if qk_norm == "fp32_layer_norm":
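The new out_context_dim argument lets the added (context/text) output projection use a different width than the visual one; when it is not passed it falls back to query_dim, so existing models are unaffected. A minimal sketch with illustrative dimensions, not Mochi's actual configuration:

from diffusers.models.attention_processor import Attention

# Illustrative sizes only; the real values come from the model config.
attn = Attention(
    query_dim=3072,            # visual (latent) stream width
    added_kv_proj_dim=1536,    # text (context) stream width
    heads=24,
    dim_head=128,
    out_context_dim=1536,      # to_add_out now projects inner_dim -> 1536 instead of -> out_dim
    context_pre_only=False,    # ensures to_add_out is created
)
print(attn.to_out[0])   # Linear(in_features=3072, out_features=3072, ...)
print(attn.to_add_out)  # Linear(in_features=3072, out_features=1536, ...)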
@@ -1792,6 +1794,7 @@ def __call__(
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
+
         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
         return hidden_states, encoder_hidden_states
@@ -3078,6 +3081,93 @@ def __call__(
         return hidden_states
 
 
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+
+        query = torch.cat([query, encoder_query], dim=2)
+        key = torch.cat([key, encoder_key], dim=2)
+        value = torch.cat([value, encoder_value], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
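A rough sketch of how the new processor might be exercised in isolation, reusing the illustrative (non-Mochi) dimensions from the Attention example above; rotary embeddings are skipped, so image_rotary_emb stays None:

import torch
from diffusers.models.attention_processor import Attention, MochiAttnProcessor2_0

# Placeholder sizes; the real values come from MochiTransformer3DModel's config.
attn = Attention(
    query_dim=3072,
    added_kv_proj_dim=1536,
    heads=24,
    dim_head=128,
    out_context_dim=1536,
    context_pre_only=False,
    qk_norm="rms_norm",
    processor=MochiAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 128, 3072)          # (batch, visual tokens, visual dim)
encoder_hidden_states = torch.randn(2, 256, 1536)  # (batch, text tokens, context dim)

# The processor concatenates both token streams, runs joint scaled-dot-product attention,
# then splits the result back into the visual and context parts.
out, context_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)          # torch.Size([2, 128, 3072])
print(context_out.shape)  # torch.Size([2, 256, 1536])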
