11import  argparse 
2+ import  pathlib 
23from  typing  import  Any , Dict 
34
45import  torch 
56from  accelerate  import  init_empty_weights 
7+ from  huggingface_hub  import  snapshot_download 
68from  transformers  import  T5EncoderModel , T5TokenizerFast 
79
8- from  diffusers  import  CosmosTransformer3DModel , EDMEulerScheduler 
10+ from  diffusers  import  AutoencoderKLCosmos ,  CosmosTransformer3DModel , EDMEulerScheduler 
911
1012
1113def  remove_keys_ (key : str , state_dict : Dict [str , Any ]):
@@ -63,10 +65,81 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
6365}
6466
# Ordered substring rewrites that turn original VAE state-dict keys into diffusers keys.
# convert_vae applies every entry, in insertion order, with ``str.replace`` on each key,
# so the order of the entries below is part of the behavior and must be kept.
VAE_KEYS_RENAME_DICT = {
    # Down blocks keep their index.
    **{f"down.{idx}": f"down_blocks.{idx}" for idx in range(3)},
    # Up blocks are renumbered in reverse: up.0 -> up_blocks.2, ..., up.2 -> up_blocks.0.
    **{f"up.{idx}": f"up_blocks.{2 - idx}" for idx in range(3)},
    ".block.": ".resnets.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    # Mid block: resnet, attention pair, resnet.
    "mid.block_1": "mid_block.resnets.0",
    "mid.attn_1.0": "mid_block.attentions.0",
    "mid.attn_1.1": "mid_block.temp_attentions.0",
    "mid.block_2": "mid_block.resnets.1",
    # Attention projections lose their ``conv3d`` wrapper suffix.
    **{f".{proj}.conv3d": f".to_{proj}" for proj in ("q", "k", "v")},
    ".proj_out.conv3d": ".to_out.0",
    ".0.conv3d": ".conv_s",
    ".1.conv3d": ".conv_t",
    **{f"conv{idx}.conv3d": f"conv{idx}" for idx in (1, 2, 3)},
    "nin_shortcut.conv3d": "conv_shortcut",
    "quant_conv.conv3d": "quant_conv",
    "post_quant_conv.conv3d": "post_quant_conv",
}
6894
# Key substrings that need an in-place handler instead of a plain rename. convert_vae calls
# the handler as ``handler(key, state_dict)`` for every state-dict key containing the substring.
# All three currently share ``remove_keys_``.
# NOTE(review): presumably these are buffers with no counterpart in the diffusers model - confirm.
VAE_SPECIAL_KEYS_REMAP = dict.fromkeys(("wavelets", "_arange", "patch_size_buffer"), remove_keys_)
100+ 
# Supported VAE variants, keyed by the value accepted for ``--vae_type``. Each entry holds the
# Hub repository to download ("name") and the keyword arguments passed to AutoencoderKLCosmos
# ("diffusers_config"). The two released versions differ only in the version string, so the
# entries are generated; every iteration builds its own independent config dict.
VAE_CONFIGS = {
    f"CV8x8x8-{version}": {
        "name": f"nvidia/Cosmos-{version}-Tokenizer-CV8x8x8",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 16,
            "encoder_block_out_channels": (128, 256, 512, 512),
            "decode_block_out_channels": (256, 512, 512, 512),
            "attention_resolutions": (32,),
            "resolution": 1024,
            "num_layers": 2,
            "patch_size": 4,
            "patch_type": "haar",
            "scaling_factor": 1.0,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 8,
            # Placeholders; convert_vae supplies the values read from the checkpoint directory.
            "latents_mean": None,
            "latents_std": None,
        },
    }
    for version in ("0.1", "1.0")
}
70143
71144
72145def  get_state_dict (saved_dict : Dict [str , Any ]) ->  Dict [str , Any ]:
@@ -105,36 +178,53 @@ def convert_transformer(ckpt_path: str):
105178    return  transformer 
106179
107180
108- # def convert_vae(ckpt_path: str): 
109- #     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) 
def _to_config_value(stat):
    """Return ``stat`` as a plain nested list when it is a tensor; pass anything else (e.g. ``None``) through."""
    if isinstance(stat, torch.Tensor):
        return stat.detach().cpu().tolist()
    return stat


def convert_vae(vae_type: str):
    """Download an original Cosmos tokenizer checkpoint and convert it to an ``AutoencoderKLCosmos``.

    The Hub repository and the model config are looked up in ``VAE_CONFIGS``. The weights are read
    from ``autoencoder.jit`` inside the downloaded snapshot; if ``mean_std.pt`` exists next to it,
    its first two entries are used as ``latents_mean`` / ``latents_std``, otherwise both are ``None``.

    Args:
        vae_type: Key of ``VAE_CONFIGS`` selecting the variant to convert.

    Returns:
        The ``AutoencoderKLCosmos`` instance with the renamed original weights loaded.

    Raises:
        KeyError: If ``vae_type`` is not a key of ``VAE_CONFIGS``.
        RuntimeError: Presumably raised by ``load_state_dict(strict=True)`` when the renamed keys do
            not match the model exactly - confirm against the torch version in use.
    """
    vae_spec = VAE_CONFIGS[vae_type]
    snapshot_directory = snapshot_download(vae_spec["name"], repo_type="model")
    directory = pathlib.Path(snapshot_directory)

    autoencoder_file = directory / "autoencoder.jit"
    mean_std_file = directory / "mean_std.pt"

    # Load on CPU, matching the torch.load call below, so conversion does not depend on the
    # device the TorchScript archive was saved from.
    original_state_dict = torch.jit.load(autoencoder_file.as_posix(), map_location="cpu").state_dict()
    if mean_std_file.exists():
        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
    else:
        mean_std = (None, None)

    # Build a fresh dict instead of updating VAE_CONFIGS in place, so the module-level table is
    # not modified by a conversion. Tensor statistics are converted to plain lists.
    # NOTE(review): the list conversion assumes the model config must stay JSON-serialisable for
    # save_pretrained - confirm against AutoencoderKLCosmos.
    config = {
        **vae_spec["diffusers_config"],
        "latents_mean": _to_config_value(mean_std[0]),
        "latents_std": _to_config_value(mean_std[1]),
    }
    vae = AutoencoderKLCosmos(**config)

    # Pass 1: plain substring renames, applied in the dict's insertion order.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Pass 2: keys that need a dedicated in-place handler. Iterate over a snapshot of the keys
    # because handlers receive the state dict and may modify it.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
128218
129219
130220def  get_args ():
131221    parser  =  argparse .ArgumentParser ()
132222    parser .add_argument (
133223        "--transformer_ckpt_path" , type = str , default = None , help = "Path to original transformer checkpoint" 
134224    )
135-     parser .add_argument ("--vae_ckpt_path " , type = str , default = None , help = "Path to original  VAE checkpoint " )
136-     parser .add_argument ("--text_encoder_path" , type = str , default = None , help = "Path to original T5 checkpoint" )
137-     parser .add_argument ("--tokenizer_path" , type = str , default = None , help = "Path to original T5 tokenizer" )
225+     parser .add_argument ("--vae_type " , type = str , default = None , choices = list ( VAE_CONFIGS . keys ()),  help = "Type of  VAE" )
226+     parser .add_argument ("--text_encoder_path" , type = str , default = None , help = "Path or HF id  to original T5 checkpoint" )
227+     parser .add_argument ("--tokenizer_path" , type = str , default = None , help = "Path or HF id  to original T5 tokenizer" )
138228    parser .add_argument ("--save_pipeline" , action = "store_true" )
139229    parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
140230    parser .add_argument ("--dtype" , default = "bf16" , help = "Torch dtype to save the transformer in." )
@@ -155,7 +245,8 @@ def get_args():
155245    dtype  =  DTYPE_MAPPING [args .dtype ]
156246
157247    if  args .save_pipeline :
158-         assert  args .transformer_ckpt_path  is  not None  and  args .vae_ckpt_path  is  not None 
248+         assert  args .transformer_ckpt_path  is  not None 
249+         assert  args .vae_type  is  not None 
159250        assert  args .text_encoder_path  is  not None 
160251        assert  args .tokenizer_path  is  not None 
161252        assert  args .text_encoder_2_path  is  not None 
@@ -166,10 +257,10 @@ def get_args():
166257        if  not  args .save_pipeline :
167258            transformer .save_pretrained (args .output_path , safe_serialization = True , max_shard_size = "5GB" )
168259
169-     #  if args.vae_ckpt_path  is not None:
170-     #       vae = convert_vae(args.vae_ckpt_path )
171-     #      if not args.save_pipeline:
172-     #          vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
260+     if  args .vae_type  is  not None :
261+         vae  =  convert_vae (args .vae_type )
262+         if  not  args .save_pipeline :
263+             vae .save_pretrained (args .output_path , safe_serialization = True , max_shard_size = "5GB" )
173264
174265    if  args .save_pipeline :
175266        text_encoder  =  T5EncoderModel .from_pretrained (args .text_encoder_path , torch_dtype = dtype )
@@ -184,6 +275,7 @@ def get_args():
184275            num_train_timesteps = 1000 ,
185276            prediction_type = "epsilon" ,
186277            rho = 7.0 ,
278+             final_sigmas_type = "sigma_min" ,
187279        )
188280
189281    # if args.save_pipeline: 
0 commit comments