
Commit 9ba6a06

[Single File] LTX support for loading original weights (#10135)
* from original file mixin for ltx
* undo config mapping fn changes
* update
1 parent 336ba36 commit 9ba6a06
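
With this change, the LTX-Video transformer and VAE can be instantiated straight from an original single-file checkpoint via from_single_file. A minimal usage sketch, assuming the classes are exported at the top level of diffusers and that a local copy of the original checkpoint exists (the filename below is illustrative, not part of this commit):

from diffusers import AutoencoderKLLTX, LTXTransformer3DModel

# Illustrative path to an original LTX-Video single-file checkpoint;
# the same file carries both the transformer and the VAE weights.
ckpt_path = "ltx-video-2b-v0.9.safetensors"

# Both classes now inherit FromOriginalModelMixin, which provides from_single_file.
transformer = LTXTransformer3DModel.from_single_file(ckpt_path)
vae = AutoencoderKLLTX.from_single_file(ckpt_path)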

File tree: 4 files changed, +120 -2 lines changed


src/diffusers/loaders/single_file_model.py

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,8 @@
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_ltx_transformer_checkpoint_to_diffusers,
+    convert_ltx_vae_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
@@ -82,6 +84,14 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "LTXTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "AutoencoderKLLTX": {
+        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
+        "default_subfolder": "vae",
+    },
 }


@@ -270,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 subfolder=subfolder,
                 local_files_only=local_files_only,
                 token=token,
+                revision=revision,
             )
         expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

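The new registry entries above are what the single-file loader dispatches on: the calling class name selects a checkpoint_mapping_fn that rewrites the original state dict into diffusers format, and default_subfolder tells the loader where to fetch the matching config from the hub repo. A simplified, illustrative sketch of that flow (not the actual loader code; config fetching and error handling are omitted):

def load_from_original_checkpoint(cls, checkpoint, diffusers_model_config):
    # Look up the conversion entry registered for this class, e.g. "AutoencoderKLLTX".
    mapping = SINGLE_FILE_LOADABLE_CLASSES[cls.__name__]
    checkpoint_mapping_fn = mapping["checkpoint_mapping_fn"]

    # Rewrite original-format keys into diffusers-format keys.
    diffusers_format_checkpoint = checkpoint_mapping_fn(
        config=diffusers_model_config, checkpoint=checkpoint
    )

    # Instantiate the model from its config and load the converted weights.
    model = cls.from_config(diffusers_model_config)
    model.load_state_dict(diffusers_format_checkpoint, strict=False)
    return model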

src/diffusers/loaders/single_file_utils.py

Lines changed: 105 additions & 0 deletions
@@ -92,6 +92,12 @@
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "ltx-video": [
+        (
+            "model.diffusion_model.patchify_proj.weight",
+            "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+        ),
+    ],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -138,6 +144,7 @@
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
 }

 # Use to configure model sample size when original config is provided
@@ -564,6 +571,10 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
+        model_type = "ltx-video"
+
     else:
         model_type = "v1"

@@ -2198,3 +2209,97 @@ def swap_scale_shift(weight):
     )

     return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "model.diffusion_model.": "",
+        "patchify_proj": "proj_in",
+        "adaln_single": "time_embed",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "vae": remove_keys_,
+    }
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    VAE_KEYS_RENAME_DICT = {
+        # common
+        "vae.": "",
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0",
+        "up_blocks.2": "up_blocks.1.upsamplers.0",
+        "up_blocks.3": "up_blocks.1",
+        "up_blocks.4": "up_blocks.2.conv_in",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.conv_in",
+        "up_blocks.8": "up_blocks.3.upsamplers.0",
+        "up_blocks.9": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.0.conv_out",
+        "down_blocks.3": "down_blocks.1",
+        "down_blocks.4": "down_blocks.1.downsamplers.0",
+        "down_blocks.5": "down_blocks.1.conv_out",
+        "down_blocks.6": "down_blocks.2",
+        "down_blocks.7": "down_blocks.2.downsamplers.0",
+        "down_blocks.8": "down_blocks.3",
+        "down_blocks.9": "mid_block",
+        # common
+        "conv_shortcut": "conv_shortcut.conv",
+        "res_blocks": "resnets",
+        "norm3.norm": "norm3",
+        "per_channel_statistics.mean-of-means": "latents_mean",
+        "per_channel_statistics.std-of-means": "latents_std",
+    }
+
+    VAE_SPECIAL_KEYS_REMAP = {
+        "per_channel_statistics.channel": remove_keys_,
+        "per_channel_statistics.mean-of-means": remove_keys_,
+        "per_channel_statistics.mean-of-stds": remove_keys_,
+        "model.diffusion_model": remove_keys_,
+    }
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
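
The transformer converter above is a pure key-renaming pass: the model.diffusion_model. prefix is stripped, a handful of substrings are rewritten, and any key containing "vae" is dropped by the special-key handler. A small round-trip on dummy tensors (the key names are examples built from the rename table rather than a real checkpoint, and the shapes are placeholders):

import torch

from diffusers.loaders.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers

dummy_checkpoint = {
    "model.diffusion_model.patchify_proj.weight": torch.zeros(1),
    "model.diffusion_model.transformer_blocks.0.attn1.q_norm.weight": torch.zeros(1),
    "vae.decoder.conv_in.weight": torch.zeros(1),  # dropped: contains "vae"
}

converted = convert_ltx_transformer_checkpoint_to_diffusers(dummy_checkpoint)
print(sorted(converted))
# ['proj_in.weight', 'transformer_blocks.0.attn1.norm_q.weight']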

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..modeling_outputs import AutoencoderKLOutput
@@ -718,7 +719,7 @@ def create_forward(*inputs):
         return hidden_states


-class AutoencoderKLLTX(ModelMixin, ConfigMixin):
+class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
     [LTX](https://huggingface.co/Lightricks/LTX-Video).

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
@@ -266,7 +267,7 @@ def forward(


 @maybe_allow_in_graph
-class LTXTransformer3DModel(ModelMixin, ConfigMixin):
+class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

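With FromOriginalModelMixin on both classes, the single-file components can also be combined with the rest of the LTX pipeline, whose remaining parts (text encoder, tokenizer, scheduler) still come from the hub repo registered in DIFFUSERS_DEFAULT_PIPELINE_PATHS. An end-to-end sketch; the pipeline class and checkpoint path are assumptions based on the existing LTX integration, not part of this commit:

import torch

from diffusers import AutoencoderKLLTX, LTXPipeline, LTXTransformer3DModel

ckpt_path = "ltx-video-2b-v0.9.safetensors"  # illustrative original checkpoint

transformer = LTXTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTX.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

# Non-single-file components are pulled from the hub; the loaded modules override them.
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)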
