@@ -74,17 +74,39 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
7474 "last_scale_shift_table" : "scale_shift_table" ,
7575}
7676
77+ VAE_095_RENAME_DICT = {
78+ # decoder
79+ "up_blocks.0" : "mid_block" ,
80+ "up_blocks.1" : "up_blocks.0.upsamplers.0" ,
81+ "up_blocks.2" : "up_blocks.0" ,
82+ "up_blocks.3" : "up_blocks.1.upsamplers.0" ,
83+ "up_blocks.4" : "up_blocks.1" ,
84+ "up_blocks.5" : "up_blocks.2.upsamplers.0" ,
85+ "up_blocks.6" : "up_blocks.2" ,
86+ "up_blocks.7" : "up_blocks.3.upsamplers.0" ,
87+ "up_blocks.8" : "up_blocks.3" ,
88+ # encoder
89+ "down_blocks.0" : "down_blocks.0" ,
90+ "down_blocks.1" : "down_blocks.0.downsamplers.0" ,
91+ "down_blocks.2" : "down_blocks.1" ,
92+ "down_blocks.3" : "down_blocks.1.downsamplers.0" ,
93+ "down_blocks.4" : "down_blocks.2" ,
94+ "down_blocks.5" : "down_blocks.2.downsamplers.0" ,
95+ "down_blocks.6" : "down_blocks.3" ,
96+ "down_blocks.7" : "down_blocks.3.downsamplers.0" ,
97+ "down_blocks.8" : "mid_block" ,
98+ # common
99+ "last_time_embedder" : "time_embedder" ,
100+ "last_scale_shift_table" : "scale_shift_table" ,
101+ }
102+
77103VAE_SPECIAL_KEYS_REMAP = {
78104 "per_channel_statistics.channel" : remove_keys_ ,
79105 "per_channel_statistics.mean-of-means" : remove_keys_ ,
80106 "per_channel_statistics.mean-of-stds" : remove_keys_ ,
81107 "model.diffusion_model" : remove_keys_ ,
82108}
83109
84- VAE_091_SPECIAL_KEYS_REMAP = {
85- "timestep_scale_multiplier" : remove_keys_ ,
86- }
87-
88110
89111def get_state_dict (saved_dict : Dict [str , Any ]) -> Dict [str , Any ]:
90112 state_dict = saved_dict
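
For context, the new VAE_095_RENAME_DICT is consumed the same way as the existing rename maps: each checkpoint key is rewritten by substring replacement and re-inserted under the new name, then the special-key handlers run. The sketch below is illustrative only; it assumes the script's update_state_dict_inplace helper (whose signature appears in the next hunk) pops the old key and stores the tensor under the new one, and the wrapper name convert_vae_keys is not part of the script.

def convert_vae_keys(state_dict, rename_dict, special_keys_remap):
    # Rewrite every key by substring replacement, e.g. an original "up_blocks.2.*"
    # prefix becomes "up_blocks.0.*" for the 0.9.5 decoder.
    for key in list(state_dict.keys()):
        new_key = key
        for old_fragment, new_fragment in rename_dict.items():
            new_key = new_key.replace(old_fragment, new_fragment)
        update_state_dict_inplace(state_dict, key, new_key)
    # Let handlers such as remove_keys_ drop entries like the per-channel statistics.
    for key in list(state_dict.keys()):
        for fragment, handler_fn in special_keys_remap.items():
            if fragment in key:
                handler_fn(key, state_dict)
    return state_dict
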
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 def convert_transformer(
     ckpt_path: str,
     dtype: torch.dtype,
+    version: str = "0.9.0",
 ):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(load_file(ckpt_path))
+    config = {}
+    if version == "0.9.5":
+        config["_use_causal_rope_fix"] = True
     with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161187 "out_channels" : 3 ,
162188 "latent_channels" : 128 ,
163189 "block_out_channels" : (128 , 256 , 512 , 512 ),
190+ "down_block_types" : (
191+ "LTXVideoDownBlock3D" ,
192+ "LTXVideoDownBlock3D" ,
193+ "LTXVideoDownBlock3D" ,
194+ "LTXVideoDownBlock3D" ,
195+ ),
164196 "decoder_block_out_channels" : (128 , 256 , 512 , 512 ),
165197 "layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
166198 "decoder_layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
167199 "spatio_temporal_scaling" : (True , True , True , False ),
168200 "decoder_spatio_temporal_scaling" : (True , True , True , False ),
169201 "decoder_inject_noise" : (False , False , False , False , False ),
202+ "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
170203 "upsample_residual" : (False , False , False , False ),
171204 "upsample_factor" : (1 , 1 , 1 , 1 ),
172205 "patch_size" : 4 ,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183216 "out_channels" : 3 ,
184217 "latent_channels" : 128 ,
185218 "block_out_channels" : (128 , 256 , 512 , 512 ),
219+ "down_block_types" : (
220+ "LTXVideoDownBlock3D" ,
221+ "LTXVideoDownBlock3D" ,
222+ "LTXVideoDownBlock3D" ,
223+ "LTXVideoDownBlock3D" ,
224+ ),
186225 "decoder_block_out_channels" : (256 , 512 , 1024 ),
187226 "layers_per_block" : (4 , 3 , 3 , 3 , 4 ),
188227 "decoder_layers_per_block" : (5 , 6 , 7 , 8 ),
189228 "spatio_temporal_scaling" : (True , True , True , False ),
190229 "decoder_spatio_temporal_scaling" : (True , True , True ),
191230 "decoder_inject_noise" : (True , True , True , False ),
231+ "downsample_type" : ("conv" , "conv" , "conv" , "conv" ),
192232 "upsample_residual" : (True , True , True ),
193233 "upsample_factor" : (2 , 2 , 2 ),
194234 "timestep_conditioning" : True ,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
200240 "decoder_causal" : False ,
201241 }
202242 VAE_KEYS_RENAME_DICT .update (VAE_091_RENAME_DICT )
203- VAE_SPECIAL_KEYS_REMAP .update (VAE_091_SPECIAL_KEYS_REMAP )
243+ elif version == "0.9.5" :
244+ config = {
245+ "in_channels" : 3 ,
246+ "out_channels" : 3 ,
247+ "latent_channels" : 128 ,
248+ "block_out_channels" : (128 , 256 , 512 , 1024 , 2048 ),
249+ "down_block_types" : (
250+ "LTXVideo095DownBlock3D" ,
251+ "LTXVideo095DownBlock3D" ,
252+ "LTXVideo095DownBlock3D" ,
253+ "LTXVideo095DownBlock3D" ,
254+ ),
255+ "decoder_block_out_channels" : (256 , 512 , 1024 ),
256+ "layers_per_block" : (4 , 6 , 6 , 2 , 2 ),
257+ "decoder_layers_per_block" : (5 , 5 , 5 , 5 ),
258+ "spatio_temporal_scaling" : (True , True , True , True ),
259+ "decoder_spatio_temporal_scaling" : (True , True , True ),
260+ "decoder_inject_noise" : (False , False , False , False ),
261+ "downsample_type" : ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
262+ "upsample_residual" : (True , True , True ),
263+ "upsample_factor" : (2 , 2 , 2 ),
264+ "timestep_conditioning" : True ,
265+ "patch_size" : 4 ,
266+ "patch_size_t" : 1 ,
267+ "resnet_norm_eps" : 1e-6 ,
268+ "scaling_factor" : 1.0 ,
269+ "encoder_causal" : True ,
270+ "decoder_causal" : False ,
271+ "spatial_compression_ratio" : 32 ,
272+ "temporal_compression_ratio" : 8 ,
273+ }
274+ VAE_KEYS_RENAME_DICT .update (VAE_095_RENAME_DICT )
204275 return config
205276
206277
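
As a sanity check on the new 0.9.5 numbers, assuming "spatial" halves only the spatial dimensions, "temporal" only the time axis, and "spatiotemporal" both: the encoder patchifies by patch_size=4 and halves space in three of the four down blocks, giving 4 × 2^3 = 32, which matches spatial_compression_ratio, while patch_size_t=1 with three blocks halving time gives 1 × 2^3 = 8, matching temporal_compression_ratio.
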
@@ -223,7 +294,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
     parser.add_argument(
-        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
     )
     return parser.parse_args()
 
@@ -277,14 +348,17 @@ def get_args():
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
 
     pipe = LTXPipeline(
         scheduler=scheduler,
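
Once converted, the directory written by the script is a regular diffusers pipeline. The snippet below only illustrates loading it back; the local path is a placeholder for whatever was passed to --output_path.

import torch
from diffusers import LTXPipeline

# Placeholder path; use the directory given to --output_path during conversion.
pipe = LTXPipeline.from_pretrained("./ltx-video-0.9.5-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")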