diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d207fafe6c5a..556f03f7992f 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -177,6 +177,7 @@ "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, + "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, @@ -638,7 +639,9 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-schnell" elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): - if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: + if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: + model_type = "ltx-video-0.9.5" + elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: model_type = "ltx-video-0.9.1" else: model_type = "ltx-video" @@ -2403,13 +2406,41 @@ def remove_keys_(key: str, state_dict): "last_scale_shift_table": "scale_shift_table", } + VAE_095_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, } - if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: + if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) for key in list(converted_state_dict.keys()):