@@ -74,6 +74,32 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
7474    "last_scale_shift_table" : "scale_shift_table" ,
7575}
7676
77+ VAE_095_RENAME_DICT  =  {
78+     # decoder 
79+     "up_blocks.0" : "mid_block" ,
80+     "up_blocks.1" : "up_blocks.0.upsamplers.0" ,
81+     "up_blocks.2" : "up_blocks.0" ,
82+     "up_blocks.3" : "up_blocks.1.upsamplers.0" ,
83+     "up_blocks.4" : "up_blocks.1" ,
84+     "up_blocks.5" : "up_blocks.2.upsamplers.0" ,
85+     "up_blocks.6" : "up_blocks.2" ,
86+     "up_blocks.7" : "up_blocks.3.upsamplers.0" ,
87+     "up_blocks.8" : "up_blocks.3" ,
88+     # encoder 
89+     "down_blocks.0" : "down_blocks.0" ,
90+     "down_blocks.1" : "down_blocks.0.downsamplers.0" ,
91+     "down_blocks.2" : "down_blocks.1" ,
92+     "down_blocks.3" : "down_blocks.1.downsamplers.0" ,
93+     "down_blocks.4" : "down_blocks.2" ,
94+     "down_blocks.5" : "down_blocks.2.downsamplers.0" ,
95+     "down_blocks.6" : "down_blocks.3" ,
96+     "down_blocks.7" : "down_blocks.3.downsamplers.0" ,
97+     "down_blocks.8" : "mid_block" ,
98+     # common 
99+     "last_time_embedder" : "time_embedder" ,
100+     "last_scale_shift_table" : "scale_shift_table" ,
101+ }
102+ 
77103VAE_SPECIAL_KEYS_REMAP  =  {
78104    "per_channel_statistics.channel" : remove_keys_ ,
79105    "per_channel_statistics.mean-of-means" : remove_keys_ ,
@@ -85,6 +111,10 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
85111    "timestep_scale_multiplier" : remove_keys_ ,
86112}
87113
114+ VAE_095_SPECIAL_KEYS_REMAP  =  {
115+     "timestep_scale_multiplier" : remove_keys_ ,
116+ }
117+ 
88118
89119def  get_state_dict (saved_dict : Dict [str , Any ]) ->  Dict [str , Any ]:
90120    state_dict  =  saved_dict 
@@ -161,12 +191,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161191            "out_channels" : 3 ,
162192            "latent_channels" : 128 ,
163193            "block_out_channels" : (128 , 256 , 512 , 512 ),
194+             "down_block_types" : (
195+                 "LTXVideoDownBlock3D" ,
196+                 "LTXVideoDownBlock3D" ,
197+                 "LTXVideoDownBlock3D" ,
198+                 "LTXVideoDownBlock3D" ,
199+             ),
164200            "decoder_block_out_channels" : (128 , 256 , 512 , 512 ),
165201            "layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
166202            "decoder_layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
167203            "spatio_temporal_scaling" : (True , True , True , False ),
168204            "decoder_spatio_temporal_scaling" : (True , True , True , False ),
169205            "decoder_inject_noise" : (False , False , False , False , False ),
206+             "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
170207            "upsample_residual" : (False , False , False , False ),
171208            "upsample_factor" : (1 , 1 , 1 , 1 ),
172209            "patch_size" : 4 ,
@@ -183,12 +220,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183220            "out_channels" : 3 ,
184221            "latent_channels" : 128 ,
185222            "block_out_channels" : (128 , 256 , 512 , 512 ),
223+             "down_block_types" : (
224+                 "LTXVideoDownBlock3D" ,
225+                 "LTXVideoDownBlock3D" ,
226+                 "LTXVideoDownBlock3D" ,
227+                 "LTXVideoDownBlock3D" ,
228+             ),
186229            "decoder_block_out_channels" : (256 , 512 , 1024 ),
187230            "layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
188231            "decoder_layers_per_block" : (5 , 6 , 7 , 8 ),
189232            "spatio_temporal_scaling" : (True , True , True , False ),
190233            "decoder_spatio_temporal_scaling" : (True , True , True ),
191234            "decoder_inject_noise" : (True , True , True , False ),
235+             "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
192236            "upsample_residual" : (True , True , True ),
193237            "upsample_factor" : (2 , 2 , 2 ),
194238            "timestep_conditioning" : True ,
@@ -201,6 +245,37 @@ def get_vae_config(version: str) -> Dict[str, Any]:
201245        }
202246        VAE_KEYS_RENAME_DICT .update (VAE_091_RENAME_DICT )
203247        VAE_SPECIAL_KEYS_REMAP .update (VAE_091_SPECIAL_KEYS_REMAP )
248+     elif  version  ==  "0.9.5" :
249+         config  =  {
250+             "in_channels" : 3 ,
251+             "out_channels" : 3 ,
252+             "latent_channels" : 128 ,
253+             "block_out_channels" : (128 , 256 , 512 , 1024 , 2048 ),
254+             "down_block_types" : (
255+                 "LTXVideo095DownBlock3D" ,
256+                 "LTXVideo095DownBlock3D" ,
257+                 "LTXVideo095DownBlock3D" ,
258+                 "LTXVideo095DownBlock3D" ,
259+             ),
260+             "decoder_block_out_channels" : (256 , 512 , 1024 ),
261+             "layers_per_block" : (4 , 6 , 6 , 2 , 2 ),
262+             "decoder_layers_per_block" : (5 , 5 , 5 , 5 ),
263+             "spatio_temporal_scaling" : (True , True , True , True ),
264+             "decoder_spatio_temporal_scaling" : (True , True , True ),
265+             "decoder_inject_noise" : (False , False , False , False ),
266+             "downsample_type" : ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
267+             "upsample_residual" : (True , True , True ),
268+             "upsample_factor" : (2 , 2 , 2 ),
269+             "timestep_conditioning" : True ,
270+             "patch_size" : 4 ,
271+             "patch_size_t" : 1 ,
272+             "resnet_norm_eps" : 1e-6 ,
273+             "scaling_factor" : 1.0 ,
274+             "encoder_causal" : True ,
275+             "decoder_causal" : False ,
276+         }
277+         VAE_KEYS_RENAME_DICT .update (VAE_095_RENAME_DICT )
278+         VAE_SPECIAL_KEYS_REMAP .update (VAE_095_SPECIAL_KEYS_REMAP )
204279    return  config 
205280
206281
@@ -223,7 +298,7 @@ def get_args():
223298    parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
224299    parser .add_argument ("--dtype" , default = "fp32" , help = "Torch dtype to save the model in." )
225300    parser .add_argument (
226-         "--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ], help = "Version of the LTX model" 
301+         "--version" , type = str , default = "0.9.0" , choices = ["0.9.0" , "0.9.1" ,  "0.9.5" ], help = "Version of the LTX model" 
227302    )
228303    return  parser .parse_args ()
229304
0 commit comments