|  | 
| 92 | 92 |         "double_blocks.0.img_attn.norm.key_norm.scale", | 
| 93 | 93 |         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", | 
| 94 | 94 |     ], | 
|  | 95 | +    "ltx-video": [ | 
|  | 96 | +        ( | 
|  | 97 | +            "model.diffusion_model.patchify_proj.weight", | 
|  | 98 | +            "model.diffusion_model.transformer_blocks.27.scale_shift_table", | 
|  | 99 | +        ), | 
|  | 100 | +    ], | 
| 95 | 101 | } | 
| 96 | 102 | 
 | 
| 97 | 103 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { | 
|  | 
| 138 | 144 |     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, | 
| 139 | 145 |     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, | 
| 140 | 146 |     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, | 
|  | 147 | +    "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, | 
| 141 | 148 | } | 
| 142 | 149 | 
 | 
| 143 | 150 | # Use to configure model sample size when original config is provided | 
| @@ -564,6 +571,10 @@ def infer_diffusers_model_type(checkpoint): | 
| 564 | 571 |             model_type = "flux-dev" | 
| 565 | 572 |         else: | 
| 566 | 573 |             model_type = "flux-schnell" | 
|  | 574 | + | 
|  | 575 | +    elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]): | 
|  | 576 | +        model_type = "ltx-video" | 
|  | 577 | + | 
| 567 | 578 |     else: | 
| 568 | 579 |         model_type = "v1" | 
| 569 | 580 | 
 | 
| @@ -2198,3 +2209,97 @@ def swap_scale_shift(weight): | 
| 2198 | 2209 |     ) | 
| 2199 | 2210 | 
 | 
| 2200 | 2211 |     return converted_state_dict | 
|  | 2212 | + | 
|  | 2213 | + | 
|  | 2214 | +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): | 
|  | 2215 | +    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} | 
|  | 2216 | + | 
|  | 2217 | +    def remove_keys_(key: str, state_dict): | 
|  | 2218 | +        state_dict.pop(key) | 
|  | 2219 | + | 
|  | 2220 | +    TRANSFORMER_KEYS_RENAME_DICT = { | 
|  | 2221 | +        "model.diffusion_model.": "", | 
|  | 2222 | +        "patchify_proj": "proj_in", | 
|  | 2223 | +        "adaln_single": "time_embed", | 
|  | 2224 | +        "q_norm": "norm_q", | 
|  | 2225 | +        "k_norm": "norm_k", | 
|  | 2226 | +    } | 
|  | 2227 | + | 
|  | 2228 | +    TRANSFORMER_SPECIAL_KEYS_REMAP = { | 
|  | 2229 | +        "vae": remove_keys_, | 
|  | 2230 | +    } | 
|  | 2231 | + | 
|  | 2232 | +    for key in list(converted_state_dict.keys()): | 
|  | 2233 | +        new_key = key | 
|  | 2234 | +        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): | 
|  | 2235 | +            new_key = new_key.replace(replace_key, rename_key) | 
|  | 2236 | +        converted_state_dict[new_key] = converted_state_dict.pop(key) | 
|  | 2237 | + | 
|  | 2238 | +    for key in list(converted_state_dict.keys()): | 
|  | 2239 | +        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): | 
|  | 2240 | +            if special_key not in key: | 
|  | 2241 | +                continue | 
|  | 2242 | +            handler_fn_inplace(key, converted_state_dict) | 
|  | 2243 | + | 
|  | 2244 | +    return converted_state_dict | 
|  | 2245 | + | 
|  | 2246 | + | 
|  | 2247 | +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): | 
|  | 2248 | +    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} | 
|  | 2249 | + | 
|  | 2250 | +    def remove_keys_(key: str, state_dict): | 
|  | 2251 | +        state_dict.pop(key) | 
|  | 2252 | + | 
|  | 2253 | +    VAE_KEYS_RENAME_DICT = { | 
|  | 2254 | +        # common | 
|  | 2255 | +        "vae.": "", | 
|  | 2256 | +        # decoder | 
|  | 2257 | +        "up_blocks.0": "mid_block", | 
|  | 2258 | +        "up_blocks.1": "up_blocks.0", | 
|  | 2259 | +        "up_blocks.2": "up_blocks.1.upsamplers.0", | 
|  | 2260 | +        "up_blocks.3": "up_blocks.1", | 
|  | 2261 | +        "up_blocks.4": "up_blocks.2.conv_in", | 
|  | 2262 | +        "up_blocks.5": "up_blocks.2.upsamplers.0", | 
|  | 2263 | +        "up_blocks.6": "up_blocks.2", | 
|  | 2264 | +        "up_blocks.7": "up_blocks.3.conv_in", | 
|  | 2265 | +        "up_blocks.8": "up_blocks.3.upsamplers.0", | 
|  | 2266 | +        "up_blocks.9": "up_blocks.3", | 
|  | 2267 | +        # encoder | 
|  | 2268 | +        "down_blocks.0": "down_blocks.0", | 
|  | 2269 | +        "down_blocks.1": "down_blocks.0.downsamplers.0", | 
|  | 2270 | +        "down_blocks.2": "down_blocks.0.conv_out", | 
|  | 2271 | +        "down_blocks.3": "down_blocks.1", | 
|  | 2272 | +        "down_blocks.4": "down_blocks.1.downsamplers.0", | 
|  | 2273 | +        "down_blocks.5": "down_blocks.1.conv_out", | 
|  | 2274 | +        "down_blocks.6": "down_blocks.2", | 
|  | 2275 | +        "down_blocks.7": "down_blocks.2.downsamplers.0", | 
|  | 2276 | +        "down_blocks.8": "down_blocks.3", | 
|  | 2277 | +        "down_blocks.9": "mid_block", | 
|  | 2278 | +        # common | 
|  | 2279 | +        "conv_shortcut": "conv_shortcut.conv", | 
|  | 2280 | +        "res_blocks": "resnets", | 
|  | 2281 | +        "norm3.norm": "norm3", | 
|  | 2282 | +        "per_channel_statistics.mean-of-means": "latents_mean", | 
|  | 2283 | +        "per_channel_statistics.std-of-means": "latents_std", | 
|  | 2284 | +    } | 
|  | 2285 | + | 
|  | 2286 | +    VAE_SPECIAL_KEYS_REMAP = { | 
|  | 2287 | +        "per_channel_statistics.channel": remove_keys_, | 
|  | 2288 | +        "per_channel_statistics.mean-of-means": remove_keys_, | 
|  | 2289 | +        "per_channel_statistics.mean-of-stds": remove_keys_, | 
|  | 2290 | +        "model.diffusion_model": remove_keys_, | 
|  | 2291 | +    } | 
|  | 2292 | + | 
|  | 2293 | +    for key in list(converted_state_dict.keys()): | 
|  | 2294 | +        new_key = key | 
|  | 2295 | +        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): | 
|  | 2296 | +            new_key = new_key.replace(replace_key, rename_key) | 
|  | 2297 | +        converted_state_dict[new_key] = converted_state_dict.pop(key) | 
|  | 2298 | + | 
|  | 2299 | +    for key in list(converted_state_dict.keys()): | 
|  | 2300 | +        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): | 
|  | 2301 | +            if special_key not in key: | 
|  | 2302 | +                continue | 
|  | 2303 | +            handler_fn_inplace(key, converted_state_dict) | 
|  | 2304 | + | 
|  | 2305 | +    return converted_state_dict | 
0 commit comments