 from huggingface_hub import snapshot_download
 from transformers import T5EncoderModel, T5TokenizerFast
 
-from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import (
+    AutoencoderKLCosmos,
+    AutoencoderKLWan,
+    CosmosTextToImagePipeline,
+    CosmosTextToWorldPipeline,
+    CosmosTransformer3DModel,
+    EDMEulerScheduler,
+    FlowMatchEulerEDMCosmos2_0Scheduler,
+)
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -29,7 +37,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-TRANSFORMER_KEYS_RENAME_DICT = {
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
     "t_embedder.1": "time_embed.t_embedder",
     "affline_norm": "time_embed.norm",
     ".blocks.0.block.attn": ".attn1",
@@ -56,14 +64,53 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
5664 "final_layer.linear" : "proj_out" ,
5765}
5866
59- TRANSFORMER_SPECIAL_KEYS_REMAP = {
67+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
6068 "blocks.block" : rename_transformer_blocks_ ,
6169 "logvar.0.freqs" : remove_keys_ ,
6270 "logvar.0.phases" : remove_keys_ ,
6371 "logvar.1.weight" : remove_keys_ ,
6472 "pos_embedder.seq" : remove_keys_ ,
6573}
6674
75+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
76+ "t_embedder.1" : "time_embed.t_embedder" ,
77+ "t_embedding_norm" : "time_embed.norm" ,
78+ "blocks" : "transformer_blocks" ,
79+ "adaln_modulation_self_attn.1" : "norm1.linear_1" ,
80+ "adaln_modulation_self_attn.2" : "norm1.linear_2" ,
81+ "adaln_modulation_cross_attn.1" : "norm2.linear_1" ,
82+ "adaln_modulation_cross_attn.2" : "norm2.linear_2" ,
83+ "adaln_modulation_mlp.1" : "norm3.linear_1" ,
84+ "adaln_modulation_mlp.2" : "norm3.linear_2" ,
85+ "self_attn" : "attn1" ,
86+ "cross_attn" : "attn2" ,
87+ "q_proj" : "to_q" ,
88+ "k_proj" : "to_k" ,
89+ "v_proj" : "to_v" ,
90+ "output_proj" : "to_out.0" ,
91+ "q_norm" : "norm_q" ,
92+ "k_norm" : "norm_k" ,
93+ "mlp.layer1" : "ff.net.0.proj" ,
94+ "mlp.layer2" : "ff.net.2" ,
95+ "x_embedder.proj.1" : "patch_embed.proj" ,
96+ # "extra_pos_embedder": "learnable_pos_embed",
97+ "final_layer.adaln_modulation.1" : "norm_out.linear_1" ,
98+ "final_layer.adaln_modulation.2" : "norm_out.linear_2" ,
99+ "final_layer.linear" : "proj_out" ,
100+ }
101+
102+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
103+ "accum_video_sample_counter" : remove_keys_ ,
104+ "accum_image_sample_counter" : remove_keys_ ,
105+ "accum_iteration" : remove_keys_ ,
106+ "accum_train_in_hours" : remove_keys_ ,
107+ "pos_embedder.seq" : remove_keys_ ,
108+ "pos_embedder.dim_spatial_range" : remove_keys_ ,
109+ "pos_embedder.dim_temporal_range" : remove_keys_ ,
110+ "_extra_state" : remove_keys_ ,
111+ }
112+
113+
67114TRANSFORMER_CONFIGS = {
68115 "Cosmos-1.0-Diffusion-7B-Text2World" : {
69116 "in_channels" : 16 ,
@@ -125,6 +172,21 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
125172 "concat_padding_mask" : True ,
126173 "extra_pos_embed_type" : "learnable" ,
127174 },
175+ "Cosmos-2.0-Diffusion-2B-Text2Image" : {
176+ "in_channels" : 16 ,
177+ "out_channels" : 16 ,
178+ "num_attention_heads" : 16 ,
179+ "attention_head_dim" : 128 ,
180+ "num_layers" : 28 ,
181+ "mlp_ratio" : 4.0 ,
182+ "text_embed_dim" : 1024 ,
183+ "adaln_lora_dim" : 256 ,
184+ "max_size" : (128 , 240 , 240 ),
185+ "patch_size" : (1 , 2 , 2 ),
186+ "rope_scale" : (1.0 , 1.0 , 1.0 ),
187+ "concat_padding_mask" : True ,
188+ "extra_pos_embed_type" : None ,
189+ },
128190}
129191
130192VAE_KEYS_RENAME_DICT = {
@@ -216,9 +278,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-def convert_transformer(transformer_type: str, ckpt_path: str):
+def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
     PREFIX_KEY = "net."
-    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
+
+    if "Cosmos-1.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    elif "Cosmos-2.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    else:
+        assert False
 
     with init_empty_weights():
         config = TRANSFORMER_CONFIGS[transformer_type]
@@ -281,13 +352,66 @@ def convert_vae(vae_type: str):
     return vae
 
 
+def save_pipeline_cosmos_1_0(args, transformer, vae, dtype):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+    # The original code initializes the EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
+    # So the sigma_min value that is actually used is the default value of 0.002.
+    scheduler = EDMEulerScheduler(
+        sigma_min=0.002,
+        sigma_max=80,
+        sigma_data=0.5,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe = CosmosTextToWorldPipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def save_pipeline_cosmos_2_0(args, transformer, vae, dtype):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+
+    scheduler = FlowMatchEulerEDMCosmos2_0Scheduler(
+        sigma_min=0.0002,
+        sigma_max=80,
+        sigma_data=1.0,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe = CosmosTextToImagePipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
-    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
+    parser.add_argument(
+        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+    )
     parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
     parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
     parser.add_argument("--save_pipeline", action="store_true")
@@ -316,37 +440,26 @@ def get_args():
     assert args.tokenizer_path is not None
 
     if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
+        weights_only = "Cosmos-1.0" in args.transformer_type
+        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
         transformer = transformer.to(dtype=dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.vae_type is not None:
-        vae = convert_vae(args.vae_type)
+        if "Cosmos-1.0" in args.transformer_type:
+            vae = convert_vae(args.vae_type)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+            )
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
-        text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
-        tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
-        # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
-        # So, the sigma_min values that is used is the default value of 0.002.
-        scheduler = EDMEulerScheduler(
-            sigma_min=0.002,
-            sigma_max=80,
-            sigma_data=0.5,
-            sigma_schedule="karras",
-            num_train_timesteps=1000,
-            prediction_type="epsilon",
-            rho=7.0,
-            final_sigmas_type="sigma_min",
-        )
-
-        pipe = CosmosTextToWorldPipeline(
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            transformer=transformer,
-            vae=vae,
-            scheduler=scheduler,
-        )
-        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+        if "Cosmos-1.0" in args.transformer_type:
+            save_pipeline_cosmos_1_0(args, transformer, vae, dtype)
+        elif "Cosmos-2.0" in args.transformer_type:
+            save_pipeline_cosmos_2_0(args, transformer, vae, dtype)
+        else:
+            assert False
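
Note (not part of the diff): the hunks above only add the Cosmos-2.0 rename tables and select which table to use; the loop that actually applies them is unchanged and not shown here. As a rough, self-contained sketch of the pattern these conversion scripts follow, the rename dict is applied as substring replacements over every checkpoint key, and keys matching an entry in the special-keys table are handed to a handler such as remove_keys_ (the helper name apply_renames below is made up for illustration):

    def apply_renames(state_dict, rename_dict, special_keys_remap):
        # Substring-based renames: every old fragment found in a key is replaced by its new fragment.
        for key in list(state_dict.keys()):
            new_key = key
            for old, new in rename_dict.items():
                new_key = new_key.replace(old, new)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)
        # Special keys are dispatched to handlers instead, e.g. remove_keys_ drops
        # training-only entries such as "_extra_state" or the accum_* counters.
        for key in list(state_dict.keys()):
            for fragment, handler in special_keys_remap.items():
                if fragment in key and key in state_dict:
                    handler(key, state_dict)
        return state_dict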
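
A checkpoint converted with --save_pipeline is a standard diffusers pipeline folder, so a Cosmos-1.0 Text2World conversion could afterwards be loaded roughly as below. This is only a sketch: the output directory and prompt are placeholders, and depending on the diffusers version the Cosmos guardrail/safety checker may additionally need to be installed and passed to the pipeline.

    import torch
    from diffusers import CosmosTextToWorldPipeline
    from diffusers.utils import export_to_video

    # Placeholder path: the --output_path given to the conversion script.
    pipe = CosmosTextToWorldPipeline.from_pretrained("./converted-cosmos-text2world", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Placeholder prompt; the call returns video frames like other diffusers video pipelines.
    video = pipe(prompt="A robot arm assembling a wooden crate in a workshop").frames[0]
    export_to_video(video, "output.mp4", fps=30)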