11import argparse
22from typing import Any , Dict
3+ from pathlib import Path
34
45import torch
6+ from accelerate import init_empty_weights
57from safetensors .torch import load_file
68from transformers import T5EncoderModel , T5Tokenizer
79
@@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2123 "k_norm" : "norm_k" ,
2224}
2325
# Maps a substring of a checkpoint key to an in-place handler that is called
# as ``handler(key, state_dict)``.
# NOTE(review): "vae" presumably drops VAE weights that ship in the same
# checkpoint file as the transformer -- confirm against remove_keys_.
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "vae": remove_keys_,
}
2529
2630VAE_KEYS_RENAME_DICT = {
2731 # decoder
@@ -54,10 +58,33 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
5458 "per_channel_statistics.std-of-means" : "latents_std" ,
5559}
5660
# Extra key renames for the 0.9.1 VAE checkpoint layout. get_vae_config()
# merges these into VAE_KEYS_RENAME_DICT when version "0.9.1" is requested.
#
# Keep the insertion order: the conversion loop applies every entry in turn
# with str.replace, so the text produced by an earlier rename (for example
# "up_blocks.0.upsamplers.0") can be matched again by a later pattern.
VAE_091_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # common
    "per_channel_scale1": "scale1",
    "per_channel_scale2": "scale2",
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}
78+
# Maps a substring of a VAE checkpoint key to an in-place handler that is
# called as ``handler(key, state_dict)``. Every entry here uses remove_keys_.
# NOTE(review): "model.diffusion_model" presumably discards transformer
# weights found in a combined checkpoint -- confirm against remove_keys_.
VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
    "model.diffusion_model": remove_keys_,
}
85+
# Additional special-key handlers for 0.9.1 VAE checkpoints; get_vae_config()
# merges them into VAE_SPECIAL_KEYS_REMAP when version "0.9.1" is requested.
VAE_091_SPECIAL_KEYS_REMAP = {
    "timestep_scale_multiplier": remove_keys_,
}
6289
6390
@@ -80,13 +107,16 @@ def convert_transformer(
80107 ckpt_path : str ,
81108 dtype : torch .dtype ,
82109):
83- PREFIX_KEY = ""
110+ PREFIX_KEY = "model.diffusion_model. "
84111
85112 original_state_dict = get_state_dict (load_file (ckpt_path ))
86- transformer = LTXVideoTransformer3DModel ().to (dtype = dtype )
113+ with init_empty_weights ():
114+ transformer = LTXVideoTransformer3DModel ()
87115
88116 for key in list (original_state_dict .keys ()):
89- new_key = key [len (PREFIX_KEY ) :]
117+ new_key = key [:]
118+ if new_key .startswith (PREFIX_KEY ):
119+ new_key = key [len (PREFIX_KEY ) :]
90120 for replace_key , rename_key in TRANSFORMER_KEYS_RENAME_DICT .items ():
91121 new_key = new_key .replace (replace_key , rename_key )
92122 update_state_dict_inplace (original_state_dict , key , new_key )
@@ -97,16 +127,21 @@ def convert_transformer(
97127 continue
98128 handler_fn_inplace (key , original_state_dict )
99129
100- transformer .load_state_dict (original_state_dict , strict = True )
130+ transformer .load_state_dict (original_state_dict , strict = True , assign = True )
101131 return transformer
102132
103133
104- def convert_vae (ckpt_path : str , dtype : torch .dtype ):
134+ def convert_vae (ckpt_path : str , config , dtype : torch .dtype ):
135+ PREFIX_KEY = "vae."
136+
105137 original_state_dict = get_state_dict (load_file (ckpt_path ))
106- vae = AutoencoderKLLTXVideo ().to (dtype = dtype )
138+ with init_empty_weights ():
139+ vae = AutoencoderKLLTXVideo (** config )
107140
108141 for key in list (original_state_dict .keys ()):
109142 new_key = key [:]
143+ if new_key .startswith (PREFIX_KEY ):
144+ new_key = key [len (PREFIX_KEY ) :]
110145 for replace_key , rename_key in VAE_KEYS_RENAME_DICT .items ():
111146 new_key = new_key .replace (replace_key , rename_key )
112147 update_state_dict_inplace (original_state_dict , key , new_key )
@@ -117,9 +152,107 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
117152 continue
118153 handler_fn_inplace (key , original_state_dict )
119154
120- vae .load_state_dict (original_state_dict , strict = True )
155+ vae .load_state_dict (original_state_dict , strict = True , assign = True )
121156 return vae
122157
158+ # OURS_VAE_CONFIG = {
159+ # "_class_name": "CausalVideoAutoencoder",
160+ # "dims": 3,
161+ # "in_channels": 3,
162+ # "out_channels": 3,
163+ # "latent_channels": 128,
164+ # "blocks": [
165+ # ["res_x", 4],
166+ # ["compress_all", 1],
167+ # ["res_x_y", 1],
168+ # ["res_x", 3],
169+ # ["compress_all", 1],
170+ # ["res_x_y", 1],
171+ # ["res_x", 3],
172+ # ["compress_all", 1],
173+ # ["res_x", 3],
174+ # ["res_x", 4],
175+ # ],
176+ # "scaling_factor": 1.0,
177+ # "norm_layer": "pixel_norm",
178+ # "patch_size": 4,
179+ # "latent_log_var": "uniform",
180+ # "use_quant_conv": False,
181+ # "causal_decoder": False,
182+ # }
183+
184+ # {
185+ # "_class_name": "CausalVideoAutoencoder",
186+ # "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128,
187+ # "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]],
188+
189+ # previous decoder
190+ # mid: resx
191+ # resx
192+ # compress_all, resx
193+ # resxy, compress_all, resx
194+ # resxy, compress_all, resx
195+
196+ # "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]],
197+
198+ # current decoder
199+ # mid: resx
200+ # compress_all, resx
201+ # compress_all, resx
202+ # compress_all, resx
203+
204+ # "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true
205+ # }
206+
def get_vae_config(version: str) -> Dict[str, Any]:
    """Return the VAE constructor keyword arguments for an LTX model version.

    Args:
        version: Model version string; ``"0.9.0"`` and ``"0.9.1"`` are
            supported.

    Returns:
        A new dict of keyword arguments, passed by ``convert_vae`` as
        ``AutoencoderKLLTXVideo(**config)``.

    Raises:
        ValueError: If ``version`` is not a supported version.

    Side effects:
        For ``"0.9.1"`` the module-level ``VAE_KEYS_RENAME_DICT`` and
        ``VAE_SPECIAL_KEYS_REMAP`` are extended in place with the 0.9.1
        entries, so this must be called before ``convert_vae``.
    """
    if version == "0.9.0":
        return {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "decoder_block_out_channels": (128, 256, 512, 512),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (4, 3, 3, 3, 4),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True, False),
            "decoder_inject_noise": (False, False, False, False),
            "upsample_residual": (False, False, False, False),
            "upsample_factor": (1, 1, 1, 1),
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
            "timestep_conditioning": False,
        }

    if version == "0.9.1":
        # The conversion loop reads these module-level tables, so the 0.9.1
        # key mappings are merged into them here. dict.update is idempotent,
        # so repeated calls are harmless.
        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
        return {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (5, 6, 7, 8),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, True, True, True),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
        }

    # Without this, an unknown version reached `return config` with the name
    # unbound and failed with an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported LTX VAE version {version!r}; expected '0.9.0' or '0.9.1'."
    )
255+
123256
124257def get_args ():
125258 parser = argparse .ArgumentParser ()
@@ -139,6 +272,7 @@ def get_args():
139272 parser .add_argument ("--save_pipeline" , action = "store_true" )
140273 parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
141274 parser .add_argument ("--dtype" , default = "fp32" , help = "Torch dtype to save the model in." )
275+ parser .add_argument ("--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ], help = "Version of the LTX model" )
142276 return parser .parse_args ()
143277
144278
@@ -161,6 +295,7 @@ def get_args():
161295 transformer = None
162296 dtype = DTYPE_MAPPING [args .dtype ]
163297 variant = VARIANT_MAPPING [args .dtype ]
298+ output_path = Path (args .output_path )
164299
165300 if args .save_pipeline :
166301 assert args .transformer_ckpt_path is not None and args .vae_ckpt_path is not None
@@ -169,13 +304,14 @@ def get_args():
169304 transformer : LTXVideoTransformer3DModel = convert_transformer (args .transformer_ckpt_path , dtype )
170305 if not args .save_pipeline :
171306 transformer .save_pretrained (
172- args . output_path , safe_serialization = True , max_shard_size = "5GB" , variant = variant
307+ output_path / "transformer" , safe_serialization = True , max_shard_size = "5GB" , variant = variant
173308 )
174309
175310 if args .vae_ckpt_path is not None :
176- vae : AutoencoderKLLTXVideo = convert_vae (args .vae_ckpt_path , dtype )
311+ config = get_vae_config (args .version )
312+ vae : AutoencoderKLLTXVideo = convert_vae (args .vae_ckpt_path , config , dtype )
177313 if not args .save_pipeline :
178- vae .save_pretrained (args . output_path , safe_serialization = True , max_shard_size = "5GB" , variant = variant )
314+ vae .save_pretrained (output_path / "vae" , safe_serialization = True , max_shard_size = "5GB" , variant = variant )
179315
180316 if args .save_pipeline :
181317 text_encoder_id = "google/t5-v1_1-xxl"
0 commit comments