
Commit 2798ed4

update conversion script
1 parent d41198c commit 2798ed4

File tree

2 files changed: +128 / -18 lines changed

scripts/convert_mochi_to_diffusers.py
src/diffusers/pipelines/mochi/pipeline_mochi.py

scripts/convert_mochi_to_diffusers.py

Lines changed: 127 additions & 17 deletions
@@ -4,9 +4,9 @@
 import torch
 from accelerate import init_empty_weights
 from safetensors.torch import load_file
+from transformers import T5EncoderModel, T5Tokenizer
 
-# from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import MochiTransformer3DModel
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available
 
 
@@ -16,7 +16,7 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
-# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
 parser.add_argument("--output_path", required=True, type=str)
 parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
 parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
@@ -144,9 +144,106 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
     return new_state_dict
 
 
-# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
-#     original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
-#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+def convert_mochi_decoder_state_dict_to_diffusers(ckpt_path):
+    original_state_dict = load_file(ckpt_path, device="cpu")
+
+    new_state_dict = {}
+    prefix = "decoder."
+
+    # Convert conv_in
+    new_state_dict[f"{prefix}conv_in.weight"] = original_state_dict["blocks.0.0.weight"]
+    new_state_dict[f"{prefix}conv_in.bias"] = original_state_dict["blocks.0.0.bias"]
+
+    # Convert block_in (MochiMidBlock3D)
+    for i in range(3):  # layers_per_block[-1] = 3
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.0.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.0.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.2.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.2.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.3.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.3.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.5.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+            f"blocks.0.{i+1}.stack.5.bias"
+        ]
+
+    # Convert up_blocks (MochiUpBlock3D)
+    up_block_layers = [6, 4, 3]  # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
+    for block in range(3):
+        for i in range(up_block_layers[block]):
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.0.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.0.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.2.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.2.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.3.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.3.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.5.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+                f"blocks.{block+1}.blocks.{i}.stack.5.bias"
+            ]
+        new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = original_state_dict[f"blocks.{block+1}.proj.weight"]
+        new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = original_state_dict[f"blocks.{block+1}.proj.bias"]
+
+    # Convert block_out (MochiMidBlock3D)
+    for i in range(3):  # layers_per_block[0] = 3
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.0.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.0.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.2.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.2.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.3.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.3.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.5.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.5.bias"
+        ]
+
+    # Convert conv_out (Conv1x1)
+    new_state_dict[f"{prefix}conv_out.weight"] = original_state_dict["output_proj.weight"]
+    new_state_dict[f"{prefix}conv_out.bias"] = original_state_dict["output_proj.bias"]
+
+    return new_state_dict
 
 
 def main(args):
@@ -162,7 +259,7 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")
 
     transformer = None
-    # vae = None
+    vae = None
 
     if args.transformer_checkpoint_path is not None:
         converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
@@ -171,18 +268,31 @@ def main(args):
         transformer = MochiTransformer3DModel()
         transformer.load_state_dict(converted_transformer_state_dict, strict=True)
         if dtype is not None:
-            # Original checkpoint data type will be preserved
             transformer = transformer.to(dtype=dtype)
 
-    # text_encoder_id = "google/t5-v1_1-xxl"
-    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
-    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
-
-    # # Apparently, the conversion does not work anymore without this :shrug:
-    # for param in text_encoder.parameters():
-    #     param.data = param.data.contiguous()
-
-    transformer.save_pretrained("/raid/aryan/mochi-diffusers", subfolder="transformer")
+    if args.vae_checkpoint_path is not None:
+        vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
+        converted_vae_state_dict = convert_mochi_decoder_state_dict_to_diffusers(args.vae_checkpoint_path)
+        vae.load_state_dict(converted_vae_state_dict, strict=True)
+        if dtype is not None:
+            vae = vae.to(dtype=dtype)
+
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Apparently, the conversion does not work anymore without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
+
+    pipe = MochiPipeline(
+        scheduler=FlowMatchEulerDiscreteScheduler(),
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
 
 
 if __name__ == "__main__":
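
With the updated main(), the script now assembles a full MochiPipeline (scheduler, VAE, text encoder, tokenizer, transformer) and saves it to --output_path instead of writing only the transformer to a hard-coded local directory. A minimal sketch of consuming that saved directory afterwards, assuming a CUDA device, bfloat16 weights, and the standard diffusers from_pretrained / export_to_video APIs; the prompt and paths are placeholders, not part of this commit:

import torch

from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Load the pipeline directory produced by the conversion script (placeholder for --output_path).
pipe = MochiPipeline.from_pretrained("path/to/converted-mochi", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Generate a short clip; .frames[0] is the first video in the batch.
frames = pipe("A cat playing a tiny piano", num_inference_steps=28).frames[0]
export_to_video(frames, "mochi_sample.mp4", fps=30)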

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def __init__(
         self.vae_temporal_scale_factor = 6
         self.patch_size = 2
 
-        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
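
The pipeline defines vae_spatial_scale_factor and vae_temporal_scale_factor rather than a single vae_scale_factor, so the VideoProcessor has to be constructed from the spatial factor. A rough sketch of the pixel-to-latent shape arithmetic these factors imply, assuming vae_spatial_scale_factor is 8 (its value is not visible in this hunk) and a (frames - 1) // temporal_factor + 1 temporal compression rule, which is also an assumption here:

vae_spatial_scale_factor = 8   # assumed; not shown in this hunk
vae_temporal_scale_factor = 6  # from the diff context above
patch_size = 2                 # from the diff context above

num_frames, height, width = 85, 480, 848
latent_height = height // vae_spatial_scale_factor                   # 60
latent_width = width // vae_spatial_scale_factor                     # 106
latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1    # 15 (assumed rule)
patch_tokens = (latent_height // patch_size) * (latent_width // patch_size)  # 1590 patches per latent frame
print(latent_frames, latent_height, latent_width, patch_tokens)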
