11import argparse
2- from typing import Any , Dict
32from pathlib import Path
3+ from typing import Any , Dict
44
55import torch
66from accelerate import init_empty_weights
@@ -133,7 +133,7 @@ def convert_transformer(
133133
134134def convert_vae (ckpt_path : str , config , dtype : torch .dtype ):
135135 PREFIX_KEY = "vae."
136-
136+
137137 original_state_dict = get_state_dict (load_file (ckpt_path ))
138138 with init_empty_weights ():
139139 vae = AutoencoderKLLTXVideo (** config )
@@ -155,54 +155,6 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
155155 vae .load_state_dict (original_state_dict , strict = True , assign = True )
156156 return vae
157157
158- # OURS_VAE_CONFIG = {
159- # "_class_name": "CausalVideoAutoencoder",
160- # "dims": 3,
161- # "in_channels": 3,
162- # "out_channels": 3,
163- # "latent_channels": 128,
164- # "blocks": [
165- # ["res_x", 4],
166- # ["compress_all", 1],
167- # ["res_x_y", 1],
168- # ["res_x", 3],
169- # ["compress_all", 1],
170- # ["res_x_y", 1],
171- # ["res_x", 3],
172- # ["compress_all", 1],
173- # ["res_x", 3],
174- # ["res_x", 4],
175- # ],
176- # "scaling_factor": 1.0,
177- # "norm_layer": "pixel_norm",
178- # "patch_size": 4,
179- # "latent_log_var": "uniform",
180- # "use_quant_conv": False,
181- # "causal_decoder": False,
182- # }
183-
184- # {
185- # "_class_name": "CausalVideoAutoencoder",
186- # "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128,
187- # "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]],
188-
189- # previous decoder
190- # mid: resx
191- # resx
192- # compress_all, resx
193- # resxy, compress_all, resx
194- # resxy, compress_all, resx
195-
196- # "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]],
197-
198- # current decoder
199- # mid: resx
200- # compress_all, resx
201- # compress_all, resx
202- # compress_all, resx
203-
204- # "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true
205- # }
206158
207159def get_vae_config (version : str ) -> Dict [str , Any ]:
208160 if version == "0.9.0" :
@@ -272,7 +224,9 @@ def get_args():
272224 parser .add_argument ("--save_pipeline" , action = "store_true" )
273225 parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
274226 parser .add_argument ("--dtype" , default = "fp32" , help = "Torch dtype to save the model in." )
275- parser .add_argument ("--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ], help = "Version of the LTX model" )
227+ parser .add_argument (
228+ "--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ], help = "Version of the LTX model"
229+ )
276230 return parser .parse_args ()
277231
278232
0 commit comments