
Commit 4b55713 (1 parent: 851dfa3)

[core] LTX Video 0.9.1 (#10330)

* update
* make style
* update
* update
* update
* make style
* single file related changes
* update
* fix
* update single file urls and docs
* update
* fix

File tree: 10 files changed, +642 −56 lines changed

docs/source/en/api/pipelines/ltx_video.md (40 additions, 2 deletions)

````diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. -->
 
-# LTX
+# LTX Video
 
 [LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
````
````diff
@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
 
 </Tip>
 
+Available models:
+
+| Model name | Recommended dtype |
+|:-------------:|:-----------------:|
+| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
+| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+
+Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
+
 ## Loading Single Files
 
-Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
+Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
 
 ```python
 import torch
 from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
 
+# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
 single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
 transformer = LTXVideoTransformer3DModel.from_single_file(
     single_file_url, torch_dtype=torch.bfloat16
````
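
The doc snippet in the hunk above is cut off by the diff context window. Based on the component loaders it shows, the example plausibly continues by loading the VAE from the same file and assembling the pipeline; the sketch below reconstructs that continuation (the VAE load and pipeline assembly are assumptions, not visible in this diff).

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
# Assumption: the VAE exposes the same single-file loader as the transformer.
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
# The remaining components (text encoder, tokenizer, scheduler) come from the hub repo.
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```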
````diff
@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
 
 Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
 
+<!-- TODO(aryan): Update this when official weights are supported -->
+
+Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
+
+```python
+import torch
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
+
+pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=768,
+    height=512,
+    num_frames=161,
+    decode_timestep=0.03,
+    decode_noise_scale=0.025,
+    num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+
 Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
 
 ## LTXPipeline
````
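
The doc defers memory tuning to the CogVideoX notes it links. As a rough illustration only, the generic diffusers hooks should carry over to `LTXPipeline`; this minimal sketch assumes `enable_model_cpu_offload` and VAE tiling are supported by these classes.

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # keep only the active submodule on the GPU
pipe.vae.enable_tiling()         # decode latents tile by tile to cap peak VRAM (assumed available)
```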

scripts/convert_ltx_to_diffusers.py (99 additions, 11 deletions)

````diff
@@ -1,7 +1,9 @@
 import argparse
+from pathlib import Path
 from typing import Any, Dict
 
 import torch
+from accelerate import init_empty_weights
 from safetensors.torch import load_file
 from transformers import T5EncoderModel, T5Tokenizer
 
@@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "k_norm": "norm_k",
 }
 
-TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "vae": remove_keys_,
+}
 
 VAE_KEYS_RENAME_DICT = {
     # decoder
@@ -54,10 +58,31 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "per_channel_statistics.std-of-means": "latents_std",
 }
 
+VAE_091_RENAME_DICT = {
+    # decoder
+    "up_blocks.0": "mid_block",
+    "up_blocks.1": "up_blocks.0.upsamplers.0",
+    "up_blocks.2": "up_blocks.0",
+    "up_blocks.3": "up_blocks.1.upsamplers.0",
+    "up_blocks.4": "up_blocks.1",
+    "up_blocks.5": "up_blocks.2.upsamplers.0",
+    "up_blocks.6": "up_blocks.2",
+    "up_blocks.7": "up_blocks.3.upsamplers.0",
+    "up_blocks.8": "up_blocks.3",
+    # common
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
+}
+
 VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_,
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
+    "model.diffusion_model": remove_keys_,
+}
+
+VAE_091_SPECIAL_KEYS_REMAP = {
+    "timestep_scale_multiplier": remove_keys_,
 }
````
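
One subtlety in `VAE_091_RENAME_DICT`: the convert functions (next hunks) apply these renames as sequential substring replacements in dict-insertion order, so a replacement may produce a name that an earlier pattern would have matched. The ordering above is what makes that safe; a self-contained sketch:

```python
# Toy reproduction of the rename loop used by convert_vae below.
RENAMES = {
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
}

def rename(key: str) -> str:
    for old, new in RENAMES.items():
        key = key.replace(old, new)
    return key

assert rename("decoder.up_blocks.0.conv.weight") == "decoder.mid_block.conv.weight"
# "up_blocks.2" -> "up_blocks.0" is not re-mapped to "mid_block", because the
# "up_blocks.0" pattern was already applied earlier in the iteration.
assert rename("decoder.up_blocks.2.conv.weight") == "decoder.up_blocks.0.conv.weight"
```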

````diff
@@ -80,13 +105,16 @@ def convert_transformer(
     ckpt_path: str,
     dtype: torch.dtype,
 ):
-    PREFIX_KEY = ""
+    PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(load_file(ckpt_path))
-    transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
+    with init_empty_weights():
+        transformer = LTXVideoTransformer3DModel()
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY) :]
+        new_key = key[:]
+        if new_key.startswith(PREFIX_KEY):
+            new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
@@ -97,16 +125,21 @@ def convert_transformer(
             continue
         handler_fn_inplace(key, original_state_dict)
 
-    transformer.load_state_dict(original_state_dict, strict=True)
+    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
     return transformer
 
 
-def convert_vae(ckpt_path: str, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
+    PREFIX_KEY = "vae."
+
     original_state_dict = get_state_dict(load_file(ckpt_path))
-    vae = AutoencoderKLLTXVideo().to(dtype=dtype)
+    with init_empty_weights():
+        vae = AutoencoderKLLTXVideo(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
+        if new_key.startswith(PREFIX_KEY):
+            new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
````
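
The switch from `Model().to(dtype=dtype)` to `init_empty_weights()` plus `load_state_dict(..., assign=True)` skips allocating (and then overwriting) randomly initialized weights. A minimal sketch of the pattern, assuming `accelerate` is installed and torch >= 2.1, where the `assign` keyword exists:

```python
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4, 4)  # parameters live on the "meta" device: no memory is allocated

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state, assign=True)  # checkpoint tensors are swapped in directly
assert layer.weight.device.type == "cpu"   # materialized without an intermediate copy
```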
````diff
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
             continue
         handler_fn_inplace(key, original_state_dict)
 
-    vae.load_state_dict(original_state_dict, strict=True)
+    vae.load_state_dict(original_state_dict, strict=True, assign=True)
     return vae
 
 
+def get_vae_config(version: str) -> Dict[str, Any]:
+    if version == "0.9.0":
+        config = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 128,
+            "block_out_channels": (128, 256, 512, 512),
+            "decoder_block_out_channels": (128, 256, 512, 512),
+            "layers_per_block": (4, 3, 3, 3, 4),
+            "decoder_layers_per_block": (4, 3, 3, 3, 4),
+            "spatio_temporal_scaling": (True, True, True, False),
+            "decoder_spatio_temporal_scaling": (True, True, True, False),
+            "decoder_inject_noise": (False, False, False, False, False),
+            "upsample_residual": (False, False, False, False),
+            "upsample_factor": (1, 1, 1, 1),
+            "patch_size": 4,
+            "patch_size_t": 1,
+            "resnet_norm_eps": 1e-6,
+            "scaling_factor": 1.0,
+            "encoder_causal": True,
+            "decoder_causal": False,
+            "timestep_conditioning": False,
+        }
+    elif version == "0.9.1":
+        config = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 128,
+            "block_out_channels": (128, 256, 512, 512),
+            "decoder_block_out_channels": (256, 512, 1024),
+            "layers_per_block": (4, 3, 3, 3, 4),
+            "decoder_layers_per_block": (5, 6, 7, 8),
+            "spatio_temporal_scaling": (True, True, True, False),
+            "decoder_spatio_temporal_scaling": (True, True, True),
+            "decoder_inject_noise": (True, True, True, False),
+            "upsample_residual": (True, True, True),
+            "upsample_factor": (2, 2, 2),
+            "timestep_conditioning": True,
+            "patch_size": 4,
+            "patch_size_t": 1,
+            "resnet_norm_eps": 1e-6,
+            "scaling_factor": 1.0,
+            "encoder_causal": True,
+            "decoder_causal": False,
+        }
+        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
+    return config
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
````
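
Note that `get_vae_config("0.9.1")` also mutates the module-level rename/remap tables as a side effect, so it has to run before `convert_vae`. From the CLI this is wired up through the new `--version` flag (e.g. `python scripts/convert_ltx_to_diffusers.py --vae_ckpt_path ltx-video-2b-v0.9.1.safetensors --version 0.9.1 --output_path ./ltx-091`, with placeholder paths). Calling the helpers directly would look roughly like this sketch:

```python
import torch

# Hypothetical direct use of the converter helpers defined above; the
# checkpoint path and output directory are placeholders.
config = get_vae_config("0.9.1")  # side effect: switches key tables to the 0.9.1 variants
vae = convert_vae("ltx-video-2b-v0.9.1.safetensors", config, torch.bfloat16)
vae.save_pretrained("./ltx-091/vae", safe_serialization=True)
```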
````diff
@@ -139,6 +222,9 @@ def get_args():
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
+    parser.add_argument(
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+    )
     return parser.parse_args()
 
 
@@ -161,6 +247,7 @@ def get_args():
     transformer = None
     dtype = DTYPE_MAPPING[args.dtype]
     variant = VARIANT_MAPPING[args.dtype]
+    output_path = Path(args.output_path)
 
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
@@ -169,13 +256,14 @@ def get_args():
         transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(
-                args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
+                output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
             )
 
     if args.vae_ckpt_path is not None:
-        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
+        config = get_vae_config(args.version)
+        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
         if not args.save_pipeline:
-            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+            vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
 
     if args.save_pipeline:
         text_encoder_id = "google/t5-v1_1-xxl"
````
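
With the `output_path / "transformer"` and `output_path / "vae"` change, converted components land in the standard diffusers subfolder layout, so they can be reloaded with `subfolder=`; a short sketch (output directory is a placeholder):

```python
from diffusers import AutoencoderKLLTXVideo, LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("./ltx-091", subfolder="transformer")
vae = AutoencoderKLLTXVideo.from_pretrained("./ltx-091", subfolder="vae")
```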

src/diffusers/loaders/single_file_utils.py (26 additions, 2 deletions)

````diff
@@ -157,7 +157,8 @@
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
-    "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
+    "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+    "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
     "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -605,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
         model_type = "flux-schnell"
 
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
-        model_type = "ltx-video"
+        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+            model_type = "ltx-video-0.9.1"
+        else:
+            model_type = "ltx-video"
 
     elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
         encoder_key = "encoder.project_in.conv.conv.bias"
````
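
The practical effect of this hunk: single-file loading now tells 0.9.0 and 0.9.1 checkpoints apart by a weight key that only the 0.9.1 VAE decoder contains (its new final time embedder), and routes each to the matching config repo. Reduced to a standalone sketch:

```python
def infer_ltx_variant(checkpoint: dict) -> str:
    """Minimal restatement of the branch added above; `checkpoint` maps tensor names to tensors."""
    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
        return "ltx-video-0.9.1"  # resolved to diffusers/LTX-Video-0.9.1
    return "ltx-video"            # resolved to diffusers/LTX-Video-0.9.0
```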
````diff
@@ -2338,12 +2342,32 @@ def remove_keys_(key: str, state_dict):
     "per_channel_statistics.std-of-means": "latents_std",
 }
 
+VAE_091_RENAME_DICT = {
+    # decoder
+    "up_blocks.0": "mid_block",
+    "up_blocks.1": "up_blocks.0.upsamplers.0",
+    "up_blocks.2": "up_blocks.0",
+    "up_blocks.3": "up_blocks.1.upsamplers.0",
+    "up_blocks.4": "up_blocks.1",
+    "up_blocks.5": "up_blocks.2.upsamplers.0",
+    "up_blocks.6": "up_blocks.2",
+    "up_blocks.7": "up_blocks.3.upsamplers.0",
+    "up_blocks.8": "up_blocks.3",
+    # common
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
+}
+
 VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_,
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
+    "timestep_scale_multiplier": remove_keys_,
 }
 
+    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+
     for key in list(converted_state_dict.keys()):
         new_key = key
         for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
````
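
For reference, the `*_SPECIAL_KEYS_REMAP` tables map key substrings to handler functions; the conversion loops (partially visible as trailing context above and in the script diff) dispatch any matching key to its handler, and `remove_keys_` simply deletes the entry. A self-contained sketch of that dispatch, with toy values standing in for tensors:

```python
def remove_keys_(key: str, state_dict: dict) -> None:
    state_dict.pop(key)  # drop bookkeeping tensors that have no diffusers counterpart

SPECIAL_KEYS_REMAP = {"per_channel_statistics.channel": remove_keys_}

sd = {"per_channel_statistics.channel": 0.0, "decoder.conv_in.weight": 1.0}
for key in list(sd.keys()):
    for special_key, handler_fn in SPECIAL_KEYS_REMAP.items():
        if special_key in key:
            handler_fn(key, sd)

assert list(sd) == ["decoder.conv_in.weight"]
```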
