 import torch
 from accelerate import init_empty_weights
 from safetensors.torch import load_file
+from transformers import T5EncoderModel, T5Tokenizer

-# from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import MochiTransformer3DModel
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available



 parser = argparse.ArgumentParser()
 parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
-# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
 parser.add_argument("--output_path", required=True, type=str)
 parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
 parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
@@ -144,9 +144,106 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
     return new_state_dict


-# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
-#     original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
-#     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+def convert_mochi_decoder_state_dict_to_diffusers(ckpt_path):
+    original_state_dict = load_file(ckpt_path, device="cpu")
+
+    new_state_dict = {}
+    prefix = "decoder."
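+    # Only decoder weights are remapped: blocks.0.0 -> conv_in, blocks.0.{1..3} -> block_in,
+    # blocks.{1..3} -> up_blocks.{0..2}, blocks.4 -> block_out, output_proj -> conv_out.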
+
+    # Convert conv_in
+    new_state_dict[f"{prefix}conv_in.weight"] = original_state_dict["blocks.0.0.weight"]
+    new_state_dict[f"{prefix}conv_in.bias"] = original_state_dict["blocks.0.0.bias"]
+
+    # Convert block_in (MochiMidBlock3D)
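+    # In each original resnet "stack", indices 0 and 3 are the two group norms and indices 2 and 5
+    # the two convs; indices 1 and 4 (presumably parameter-free activations) are not copied.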
+    for i in range(3):  # layers_per_block[-1] = 3
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.0.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.0.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.2.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.2.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.3.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.3.bias"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.5.weight"
+        ]
+        new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+            f"blocks.0.{i + 1}.stack.5.bias"
+        ]
+
+    # Convert up_blocks (MochiUpBlock3D)
+    up_block_layers = [6, 4, 3]  # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
+    for block in range(3):
+        for i in range(up_block_layers[block]):
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
+            ]
+            new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+                f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
+            ]
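+        # Each up block also carries a projection layer; copy its weights directly.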
+        new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = original_state_dict[f"blocks.{block + 1}.proj.weight"]
+        new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = original_state_dict[f"blocks.{block + 1}.proj.bias"]
+
+    # Convert block_out (MochiMidBlock3D)
+    for i in range(3):  # layers_per_block[0] = 3
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.0.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.0.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.2.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.2.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.3.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.3.bias"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = original_state_dict[
+            f"blocks.4.{i}.stack.5.weight"
+        ]
+        new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = original_state_dict[
+            f"blocks.4.{i}.stack.5.bias"
+        ]
+
+    # Convert conv_out (Conv1x1)
+    new_state_dict[f"{prefix}conv_out.weight"] = original_state_dict["output_proj.weight"]
+    new_state_dict[f"{prefix}conv_out.bias"] = original_state_dict["output_proj.bias"]
+
+    return new_state_dict


 def main(args):
@@ -162,7 +259,7 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")

     transformer = None
-    # vae = None
+    vae = None

     if args.transformer_checkpoint_path is not None:
         converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
@@ -171,18 +268,31 @@ def main(args):
         transformer = MochiTransformer3DModel()
         transformer.load_state_dict(converted_transformer_state_dict, strict=True)
         if dtype is not None:
-            # Original checkpoint data type will be preserved
             transformer = transformer.to(dtype=dtype)

-    # text_encoder_id = "google/t5-v1_1-xxl"
-    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
-    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
-
-    # # Apparently, the conversion does not work anymore without this :shrug:
-    # for param in text_encoder.parameters():
-    #     param.data = param.data.contiguous()
-
-    transformer.save_pretrained("/raid/aryan/mochi-diffusers", subfolder="transformer")
+    if args.vae_checkpoint_path is not None:
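+        # Only latent_channels and out_channels are set explicitly; all other AutoencoderKLMochi
+        # config fields keep their default values.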
+        vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
+        converted_vae_state_dict = convert_mochi_decoder_state_dict_to_diffusers(args.vae_checkpoint_path)
+        vae.load_state_dict(converted_vae_state_dict, strict=True)
+        if dtype is not None:
+            vae = vae.to(dtype=dtype)
+
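+    # The text encoder and tokenizer are not converted; the pretrained T5 v1.1 XXL checkpoint is
+    # loaded from the Hub instead.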
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Apparently, the conversion does not work anymore without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
+
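+    # Assemble the full pipeline; the scheduler is instantiated with its default configuration,
+    # since no scheduler state is converted.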
+    pipe = MochiPipeline(
+        scheduler=FlowMatchEulerDiscreteScheduler(),
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


 if __name__ == "__main__":