9393 "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" ,
9494 ],
9595 "ltx-video" : [
96- ("patchify_proj.weight" , "transformer_blocks.27.scale_shift_table" ),
96+ (
97+ "model.diffusion_model.patchify_proj.weight" ,
98+ "model.diffusion_model.transformer_blocks.27.scale_shift_table" ,
99+ ),
97100 ],
98101}
99102
@@ -2211,14 +2214,20 @@ def swap_scale_shift(weight):
22112214def convert_ltx_transformer_checkpoint_to_diffusers (checkpoint , ** kwargs ):
22122215 converted_state_dict = {key : checkpoint .pop (key ) for key in list (checkpoint .keys ())}
22132216
2217+ def remove_keys_ (key : str , state_dict ):
2218+ state_dict .pop (key )
2219+
22142220 TRANSFORMER_KEYS_RENAME_DICT = {
2221+ "model.diffusion_model." : "" ,
22152222 "patchify_proj" : "proj_in" ,
22162223 "adaln_single" : "time_embed" ,
22172224 "q_norm" : "norm_q" ,
22182225 "k_norm" : "norm_k" ,
22192226 }
22202227
2221- TRANSFORMER_SPECIAL_KEYS_REMAP = {}
2228+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
2229+ "vae" : remove_keys_ ,
2230+ }
22222231
22232232 for key in list (converted_state_dict .keys ()):
22242233 new_key = key
@@ -2242,6 +2251,8 @@ def remove_keys_(key: str, state_dict):
22422251 state_dict .pop (key )
22432252
22442253 VAE_KEYS_RENAME_DICT = {
2254+ # common
2255+ "vae." : "" ,
22452256 # decoder
22462257 "up_blocks.0" : "mid_block" ,
22472258 "up_blocks.1" : "up_blocks.0" ,
@@ -2276,6 +2287,7 @@ def remove_keys_(key: str, state_dict):
22762287 "per_channel_statistics.channel" : remove_keys_ ,
22772288 "per_channel_statistics.mean-of-means" : remove_keys_ ,
22782289 "per_channel_statistics.mean-of-stds" : remove_keys_ ,
2290+ "model.diffusion_model" : remove_keys_ ,
22792291 }
22802292
22812293 for key in list (converted_state_dict .keys ()):
0 commit comments