|
12 | 12 | from transformers import AutoModelForCausalLM, AutoTokenizer |
13 | 13 |
|
14 | 14 | from diffusers import ( |
| 15 | + AutoencoderKLLTX2Video, |
15 | 16 | AutoencoderKLWan, |
16 | 17 | DPMSolverMultistepScheduler, |
17 | 18 | FlowMatchEulerDiscreteScheduler, |
|
24 | 25 |
|
25 | 26 | CTX = init_empty_weights if is_accelerate_available else nullcontext |
26 | 27 |
|
27 | | -ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"] |
| 28 | +ckpt_ids = [ |
| 29 | + "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth", |
| 30 | + "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth", |
| 31 | +] |
28 | 32 | # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py |
29 | 33 |
|
30 | 34 |
|
@@ -92,12 +96,22 @@ def main(args): |
92 | 96 | if args.video_size == 480: |
93 | 97 | sample_size = 30 # Wan-VAE: 8xp2 downsample factor |
94 | 98 | patch_size = (1, 2, 2) |
| 99 | + in_channels = 16 |
| 100 | + out_channels = 16 |
95 | 101 | elif args.video_size == 720: |
96 | | - sample_size = 22 # Wan-VAE: 32xp1 downsample factor |
| 102 | + sample_size = 22 # DC-AE-V: 32xp1 downsample factor |
97 | 103 | patch_size = (1, 1, 1) |
| 104 | + in_channels = 32 |
| 105 | + out_channels = 32 |
98 | 106 | else: |
99 | 107 | raise ValueError(f"Video size {args.video_size} is not supported.") |
100 | 108 |
|
| 109 | + if args.vae_type == "ltx2": |
| 110 | + sample_size = 22 |
| 111 | + patch_size = (1, 1, 1) |
| 112 | + in_channels = 128 |
| 113 | + out_channels = 128 |
| 114 | + |
101 | 115 | for depth in range(layer_num): |
102 | 116 | # Transformer blocks. |
103 | 117 | converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( |
@@ -182,8 +196,8 @@ def main(args): |
182 | 196 | # Transformer |
183 | 197 | with CTX(): |
184 | 198 | transformer_kwargs = { |
185 | | - "in_channels": 16, |
186 | | - "out_channels": 16, |
| 199 | + "in_channels": in_channels, |
| 200 | + "out_channels": out_channels, |
187 | 201 | "num_attention_heads": 20, |
188 | 202 | "attention_head_dim": 112, |
189 | 203 | "num_layers": 20, |
@@ -235,9 +249,12 @@ def main(args): |
235 | 249 | else: |
236 | 250 | print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) |
237 | 251 | # VAE |
238 | | - vae = AutoencoderKLWan.from_pretrained( |
239 | | - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 |
240 | | - ) |
| 252 | + if args.vae_type == "ltx2": |
| 253 | + vae_path = args.vae_path or "Lightricks/LTX-2" |
| 254 | + vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) |
| 255 | + else: |
| 256 | + vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" |
| 257 | + vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) |
241 | 258 |
|
242 | 259 | # Text Encoder |
243 | 260 | text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" |
@@ -314,7 +331,23 @@ def main(args): |
314 | 331 | choices=["flow-dpm_solver", "flow-euler", "uni-pc"], |
315 | 332 | help="Scheduler type to use.", |
316 | 333 | ) |
317 | | - parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.") |
| 334 | + parser.add_argument( |
| 335 | + "--vae_type", |
| 336 | + default="wan", |
| 337 | + type=str, |
| 338 | + choices=["wan", "ltx2"], |
| 339 | + help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).", |
| 340 | + ) |
| 341 | + parser.add_argument( |
| 342 | + "--vae_path", |
| 343 | + default=None, |
| 344 | + type=str, |
| 345 | + required=False, |
| 346 | + help="Optional VAE path or repo id. If not set, a default is used per VAE type.", |
| 347 | + ) |
| 348 | + parser.add_argument( |
| 349 | + "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v." |
| 350 | + ) |
318 | 351 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") |
319 | 352 | parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") |
320 | 353 | parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") |
|
0 commit comments