Skip to content

Commit 0871dc6

Browse files
committed
update
1 parent b756ec6 commit 0871dc6

File tree

2 files changed

+292
-28
lines changed

2 files changed

+292
-28
lines changed

scripts/convert_ltx_to_diffusers.py

Lines changed: 147 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import argparse
22
from typing import Any, Dict
3+
from pathlib import Path
34

45
import torch
6+
from accelerate import init_empty_weights
57
from safetensors.torch import load_file
68
from transformers import T5EncoderModel, T5Tokenizer
79

@@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
2123
"k_norm": "norm_k",
2224
}
2325

24-
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
26+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
27+
"vae": remove_keys_,
28+
}
2529

2630
VAE_KEYS_RENAME_DICT = {
2731
# decoder
@@ -54,10 +58,33 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
5458
"per_channel_statistics.std-of-means": "latents_std",
5559
}
5660

61+
VAE_091_RENAME_DICT = {
62+
# decoder
63+
"up_blocks.0": "mid_block",
64+
"up_blocks.1": "up_blocks.0.upsamplers.0",
65+
"up_blocks.2": "up_blocks.0",
66+
"up_blocks.3": "up_blocks.1.upsamplers.0",
67+
"up_blocks.4": "up_blocks.1",
68+
"up_blocks.5": "up_blocks.2.upsamplers.0",
69+
"up_blocks.6": "up_blocks.2",
70+
"up_blocks.7": "up_blocks.3.upsamplers.0",
71+
"up_blocks.8": "up_blocks.3",
72+
# common
73+
"per_channel_scale1": "scale1",
74+
"per_channel_scale2": "scale2",
75+
"last_time_embedder": "time_embedder",
76+
"last_scale_shift_table": "scale_shift_table",
77+
}
78+
5779
VAE_SPECIAL_KEYS_REMAP = {
5880
"per_channel_statistics.channel": remove_keys_,
5981
"per_channel_statistics.mean-of-means": remove_keys_,
6082
"per_channel_statistics.mean-of-stds": remove_keys_,
83+
"model.diffusion_model": remove_keys_,
84+
}
85+
86+
VAE_091_SPECIAL_KEYS_REMAP = {
87+
"timestep_scale_multiplier": remove_keys_,
6188
}
6289

6390

@@ -80,13 +107,16 @@ def convert_transformer(
80107
ckpt_path: str,
81108
dtype: torch.dtype,
82109
):
83-
PREFIX_KEY = ""
110+
PREFIX_KEY = "model.diffusion_model."
84111

85112
original_state_dict = get_state_dict(load_file(ckpt_path))
86-
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
113+
with init_empty_weights():
114+
transformer = LTXVideoTransformer3DModel()
87115

88116
for key in list(original_state_dict.keys()):
89-
new_key = key[len(PREFIX_KEY) :]
117+
new_key = key[:]
118+
if new_key.startswith(PREFIX_KEY):
119+
new_key = key[len(PREFIX_KEY) :]
90120
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
91121
new_key = new_key.replace(replace_key, rename_key)
92122
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -97,16 +127,21 @@ def convert_transformer(
97127
continue
98128
handler_fn_inplace(key, original_state_dict)
99129

100-
transformer.load_state_dict(original_state_dict, strict=True)
130+
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
101131
return transformer
102132

103133

104-
def convert_vae(ckpt_path: str, dtype: torch.dtype):
134+
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
135+
PREFIX_KEY = "vae."
136+
105137
original_state_dict = get_state_dict(load_file(ckpt_path))
106-
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
138+
with init_empty_weights():
139+
vae = AutoencoderKLLTXVideo(**config)
107140

108141
for key in list(original_state_dict.keys()):
109142
new_key = key[:]
143+
if new_key.startswith(PREFIX_KEY):
144+
new_key = key[len(PREFIX_KEY) :]
110145
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
111146
new_key = new_key.replace(replace_key, rename_key)
112147
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -117,9 +152,107 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
117152
continue
118153
handler_fn_inplace(key, original_state_dict)
119154

120-
vae.load_state_dict(original_state_dict, strict=True)
155+
vae.load_state_dict(original_state_dict, strict=True, assign=True)
121156
return vae
122157

158+
# OURS_VAE_CONFIG = {
159+
# "_class_name": "CausalVideoAutoencoder",
160+
# "dims": 3,
161+
# "in_channels": 3,
162+
# "out_channels": 3,
163+
# "latent_channels": 128,
164+
# "blocks": [
165+
# ["res_x", 4],
166+
# ["compress_all", 1],
167+
# ["res_x_y", 1],
168+
# ["res_x", 3],
169+
# ["compress_all", 1],
170+
# ["res_x_y", 1],
171+
# ["res_x", 3],
172+
# ["compress_all", 1],
173+
# ["res_x", 3],
174+
# ["res_x", 4],
175+
# ],
176+
# "scaling_factor": 1.0,
177+
# "norm_layer": "pixel_norm",
178+
# "patch_size": 4,
179+
# "latent_log_var": "uniform",
180+
# "use_quant_conv": False,
181+
# "causal_decoder": False,
182+
# }
183+
184+
# {
185+
# "_class_name": "CausalVideoAutoencoder",
186+
# "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128,
187+
# "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]],
188+
189+
# previous decoder
190+
# mid: resx
191+
# resx
192+
# compress_all, resx
193+
# resxy, compress_all, resx
194+
# resxy, compress_all, resx
195+
196+
# "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]],
197+
198+
# current decoder
199+
# mid: resx
200+
# compress_all, resx
201+
# compress_all, resx
202+
# compress_all, resx
203+
204+
# "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true
205+
# }
206+
207+
def get_vae_config(version: str) -> Dict[str, Any]:
208+
if version == "0.9.0":
209+
config = {
210+
"in_channels": 3,
211+
"out_channels": 3,
212+
"latent_channels": 128,
213+
"block_out_channels": (128, 256, 512, 512),
214+
"decoder_block_out_channels": (128, 256, 512, 512),
215+
"layers_per_block": (4, 3, 3, 3, 4),
216+
"decoder_layers_per_block": (4, 3, 3, 3, 4),
217+
"spatio_temporal_scaling": (True, True, True, False),
218+
"decoder_spatio_temporal_scaling": (True, True, True, False),
219+
"decoder_inject_noise": (False, False, False, False),
220+
"upsample_residual": (False, False, False, False),
221+
"upsample_factor": (1, 1, 1, 1),
222+
"patch_size": 4,
223+
"patch_size_t": 1,
224+
"resnet_norm_eps": 1e-6,
225+
"scaling_factor": 1.0,
226+
"encoder_causal": True,
227+
"decoder_causal": False,
228+
"timestep_conditioning": False,
229+
}
230+
elif version == "0.9.1":
231+
config = {
232+
"in_channels": 3,
233+
"out_channels": 3,
234+
"latent_channels": 128,
235+
"block_out_channels": (128, 256, 512, 512),
236+
"decoder_block_out_channels": (256, 512, 1024),
237+
"layers_per_block": (4, 3, 3, 3, 4),
238+
"decoder_layers_per_block": (5, 6, 7, 8),
239+
"spatio_temporal_scaling": (True, True, True, False),
240+
"decoder_spatio_temporal_scaling": (True, True, True),
241+
"decoder_inject_noise": (False, True, True, True),
242+
"upsample_residual": (True, True, True),
243+
"upsample_factor": (2, 2, 2),
244+
"timestep_conditioning": True,
245+
"patch_size": 4,
246+
"patch_size_t": 1,
247+
"resnet_norm_eps": 1e-6,
248+
"scaling_factor": 1.0,
249+
"encoder_causal": True,
250+
"decoder_causal": False,
251+
}
252+
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
253+
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
254+
return config
255+
123256

124257
def get_args():
125258
parser = argparse.ArgumentParser()
@@ -139,6 +272,7 @@ def get_args():
139272
parser.add_argument("--save_pipeline", action="store_true")
140273
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
141274
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
275+
parser.add_argument("--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model")
142276
return parser.parse_args()
143277

144278

@@ -161,6 +295,7 @@ def get_args():
161295
transformer = None
162296
dtype = DTYPE_MAPPING[args.dtype]
163297
variant = VARIANT_MAPPING[args.dtype]
298+
output_path = Path(args.output_path)
164299

165300
if args.save_pipeline:
166301
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
@@ -169,13 +304,14 @@ def get_args():
169304
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
170305
if not args.save_pipeline:
171306
transformer.save_pretrained(
172-
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
307+
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
173308
)
174309

175310
if args.vae_ckpt_path is not None:
176-
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
311+
config = get_vae_config(args.version)
312+
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
177313
if not args.save_pipeline:
178-
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
314+
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
179315

180316
if args.save_pipeline:
181317
text_encoder_id = "google/t5-v1_1-xxl"

0 commit comments

Comments
 (0)