Commit c2a1557

add conversion script
1 parent 0e9e281 commit c2a1557

File tree: 7 files changed, +213 −5 lines changed
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
# from transformers import T5EncoderModel, T5Tokenizer

from diffusers import MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 224

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift),
# while Mochi splits it into (shift, scale).
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight

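# Illustration (hypothetical values, not part of the conversion): if the Mochi
# projection weight rows are ordered [shift_0, shift_1, scale_0, scale_1], then
# swap_scale_shift returns [scale_0, scale_1, shift_0, shift_1], which is the
# ordering diffusers' `AdaLayerNormContinuous` expects.
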
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = load_file(ckpt_path, device="cpu")
    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")

    # Convert time_embed
    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
    new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
    new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
    new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
    new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
    new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
    new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
    new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
    new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")

    # Convert transformer blocks
    num_layers = 48
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"blocks.{i}."

        # norm1
        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
                old_prefix + "mod_y.bias"
            )
        else:
            new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop(
                old_prefix + "mod_y.weight"
            )
            new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias")

        # Visual attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attn.proj_x.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")

        # Context attention
        qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_context_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_context_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_context_v.weight"] = v
        new_state_dict[block_prefix + "attn1.norm_context_q.weight"] = original_state_dict.pop(
            old_prefix + "attn.q_norm_y.weight"
        )
        new_state_dict[block_prefix + "attn1.norm_context_k.weight"] = original_state_dict.pop(
            old_prefix + "attn.k_norm_y.weight"
        )
        if i < num_layers - 1:
            new_state_dict[block_prefix + "attn1.to_context_out.0.weight"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.weight"
            )
            new_state_dict[block_prefix + "attn1.to_context_out.0.bias"] = original_state_dict.pop(
                old_prefix + "attn.proj_y.bias"
            )

        # MLP
        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
        if i < num_layers - 1:
            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w1.weight"
            )
            new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
                old_prefix + "mlp_y.w2.weight"
            )

    # Output layers
    new_state_dict["norm_out.linear.weight"] = original_state_dict.pop("final_layer.mod.weight")
    new_state_dict["norm_out.linear.bias"] = original_state_dict.pop("final_layer.mod.bias")
    new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

    new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")

    print("Remaining Keys:", original_state_dict.keys())

    return new_state_dict


# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
#     original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # Cast only when an explicit dtype was requested; otherwise the original
            # checkpoint data type is preserved.
            transformer = transformer.to(dtype=dtype)

    # text_encoder_id = "google/t5-v1_1-xxl"
    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # # Apparently, the conversion does not work anymore without this :shrug:
    # for param in text_encoder.parameters():
    #     param.data = param.data.contiguous()

    transformer.save_pretrained(args.output_path, subfolder="transformer")


if __name__ == "__main__":
    main(args)
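For orientation, a minimal sketch of how the converted checkpoint could be exercised after the script finishes; the invocation, paths, and dtype below are illustrative assumptions, not part of the commit:

# Hypothetical invocation of the conversion script above (its filename is not shown in this diff):
#   python convert_mochi_to_diffusers.py \
#       --transformer_checkpoint_path /path/to/mochi/dit.safetensors \
#       --output_path /path/to/mochi-diffusers \
#       --dtype bf16

import torch

from diffusers import MochiTransformer3DModel

# Load the converted weights back as a quick sanity check (path matches --output_path above).
transformer = MochiTransformer3DModel.from_pretrained(
    "/path/to/mochi-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(transformer.config)
print(f"parameters: {sum(p.numel() for p in transformer.parameters()):,}")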

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LuminaNextDiT2DModel",
+        "MochiTransformer3DModel",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -579,6 +580,7 @@
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -106,6 +107,7 @@
         HunyuanDiT2DModel,
         LatteTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         SD3Transformer2DModel,
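These registrations feed diffusers' lazy-import machinery: the string entry in `_import_structure` records where `MochiTransformer3DModel` lives, so the submodule is only imported when the name is first accessed. Below is a simplified sketch of that pattern using PEP 562 module-level `__getattr__`, assuming it sits in a package `__init__.py`; the real library uses its own `_LazyModule` helper rather than this exact code:

import importlib

_import_structure = {"transformers.transformer_mochi": ["MochiTransformer3DModel"]}
_name_to_module = {name: module for module, names in _import_structure.items() for name in names}


def __getattr__(name):
    # Import the owning submodule lazily on first attribute access.
    if name in _name_to_module:
        submodule = importlib.import_module("." + _name_to_module[name], __name__)
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")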

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 1 deletion
@@ -771,11 +771,14 @@ def __init__(
             nn.Linear(self.inner_dim, self.out_dim)
         ])

-        self.to_context_out = None
         if out_context_dim is not None:
             self.to_context_out = nn.ModuleList([
                 nn.Linear(self.inner_dim, out_context_dim)
             ])
+        else:
+            self.to_context_out = nn.ModuleList([
+                nn.Identity()
+            ])

         if processor is None:
             processor = AsymmetricAttnProcessor2_0()
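The change above replaces the `None` sentinel with an `nn.Identity()` wrapped in a `ModuleList`, so the attention processor can always call `self.to_context_out[0](...)` without a `None` check. A standalone sketch of that pattern with toy names (illustrative only, not the library's actual classes):

from typing import Optional

import torch
import torch.nn as nn


class TinyAsymmetricProj(nn.Module):
    """Illustrative only: optional context projection via Identity instead of None."""

    def __init__(self, inner_dim: int, out_context_dim: Optional[int]):
        super().__init__()
        if out_context_dim is not None:
            self.to_context_out = nn.ModuleList([nn.Linear(inner_dim, out_context_dim)])
        else:
            self.to_context_out = nn.ModuleList([nn.Identity()])

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        # Same call path in both cases -- no `if self.to_context_out is not None` branch.
        return self.to_context_out[0](context)


x = torch.randn(2, 8)
print(TinyAsymmetricProj(8, 4)(x).shape)     # torch.Size([2, 4])
print(TinyAsymmetricProj(8, None)(x).shape)  # torch.Size([2, 8])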

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,5 +16,6 @@
     from .transformer_2d import Transformer2DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 4 additions & 4 deletions
@@ -57,7 +57,7 @@ def __init__(
         else:
             self.norm1_context = nn.Linear(dim, pooled_projection_dim)

-        self.attn = AsymmetricAttention(
+        self.attn1 = AsymmetricAttention(
             query_dim=dim,
             query_context_dim=pooled_projection_dim,
             num_attention_heads=num_attention_heads,
@@ -66,7 +66,7 @@ def __init__(
             out_context_dim=None if context_pre_only else pooled_projection_dim,
             qk_norm=qk_norm,
             eps=1e-6,
-            elementwise_affine=False,
+            elementwise_affine=True,
             processor=AsymmetricAttnProcessor2_0(),
         )

@@ -100,7 +100,7 @@ def forward(
         else:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)

-        attn_hidden_states, context_attn_hidden_states = self.attn(
+        attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
@@ -127,7 +127,7 @@


 @maybe_allow_in_graph
-class MochiTransformer3D(ModelMixin, ConfigMixin):
+class MochiTransformer3DModel(ModelMixin, ConfigMixin):
     _supports_gradient_checkpointing = True

     @register_to_config
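The `attn` → `attn1` rename is what makes the converted weights load: `nn.Module` attribute names become state-dict key prefixes, and the conversion script writes keys such as `transformer_blocks.0.attn1.to_q.weight`. A minimal, self-contained illustration (toy module, not the real block):

import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # The attribute name (attn1), not the class name, determines the key prefix.
        self.attn1 = nn.Linear(4, 4)


print(list(ToyBlock().state_dict().keys()))
# ['attn1.weight', 'attn1.bias'] -- so load_state_dict(..., strict=True) only succeeds
# when the module attribute matches the "attn1." prefix used in the converted keys.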

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
@@ -347,6 +347,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class MochiTransformer3DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class ModelMixin(metaclass=DummyObject):
     _backends = ["torch"]
