Skip to content

Commit c03b739

Browse files
committed
add ltx vae in conversion script;
1 parent cd639f9 commit c03b739

File tree

1 file changed

+41
-8
lines changed

1 file changed

+41
-8
lines changed

scripts/convert_sana_video_to_diffusers.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from diffusers import (
+    AutoencoderKLLTX2Video,
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@
 
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
-ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+ckpt_ids = [
+    "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
+    "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
+]
 # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
 
 
@@ -92,12 +96,22 @@ def main(args):
     if args.video_size == 480:
         sample_size = 30  # Wan-VAE: 8xp2 downsample factor
         patch_size = (1, 2, 2)
+        in_channels = 16
+        out_channels = 16
     elif args.video_size == 720:
-        sample_size = 22  # Wan-VAE: 32xp1 downsample factor
+        sample_size = 22  # DC-AE-V: 32xp1 downsample factor
         patch_size = (1, 1, 1)
+        in_channels = 32
+        out_channels = 32
     else:
         raise ValueError(f"Video size {args.video_size} is not supported.")
 
+    if args.vae_type == "ltx2":
+        sample_size = 22
+        patch_size = (1, 1, 1)
+        in_channels = 128
+        out_channels = 128
+
     for depth in range(layer_num):
         # Transformer blocks.
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -182,8 +196,8 @@ def main(args):
     # Transformer
     with CTX():
         transformer_kwargs = {
-            "in_channels": 16,
-            "out_channels": 16,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
             "num_attention_heads": 20,
             "attention_head_dim": 112,
             "num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        vae = AutoencoderKLWan.from_pretrained(
-            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
-        )
+        if args.vae_type == "ltx2":
+            vae_path = args.vae_path or "Lightricks/LTX-2"
+            vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
+        else:
+            vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+            vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
 
         # Text Encoder
         text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@ def main(args):
         choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
         help="Scheduler type to use.",
     )
-    parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+    parser.add_argument(
+        "--vae_type",
+        default="wan",
+        type=str,
+        choices=["wan", "ltx2"],
+        help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
+    )
+    parser.add_argument(
+        "--vae_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
+    )
+    parser.add_argument(
+        "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
     parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

0 commit comments

Comments (0)