Skip to content

Commit ea436c4

Browse files
committed
update
1 parent a74f02f commit ea436c4

File tree

2 files changed

+277
-12
lines changed

2 files changed

+277
-12
lines changed

scripts/convert_ltx_to_diffusers.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
7474
"last_scale_shift_table": "scale_shift_table",
7575
}
7676

77+
VAE_095_RENAME_DICT = {
78+
# decoder
79+
"up_blocks.0": "mid_block",
80+
"up_blocks.1": "up_blocks.0.upsamplers.0",
81+
"up_blocks.2": "up_blocks.0",
82+
"up_blocks.3": "up_blocks.1.upsamplers.0",
83+
"up_blocks.4": "up_blocks.1",
84+
"up_blocks.5": "up_blocks.2.upsamplers.0",
85+
"up_blocks.6": "up_blocks.2",
86+
"up_blocks.7": "up_blocks.3.upsamplers.0",
87+
"up_blocks.8": "up_blocks.3",
88+
# encoder
89+
"down_blocks.0": "down_blocks.0",
90+
"down_blocks.1": "down_blocks.0.downsamplers.0",
91+
"down_blocks.2": "down_blocks.1",
92+
"down_blocks.3": "down_blocks.1.downsamplers.0",
93+
"down_blocks.4": "down_blocks.2",
94+
"down_blocks.5": "down_blocks.2.downsamplers.0",
95+
"down_blocks.6": "down_blocks.3",
96+
"down_blocks.7": "down_blocks.3.downsamplers.0",
97+
"down_blocks.8": "mid_block",
98+
# common
99+
"last_time_embedder": "time_embedder",
100+
"last_scale_shift_table": "scale_shift_table",
101+
}
102+
77103
VAE_SPECIAL_KEYS_REMAP = {
78104
"per_channel_statistics.channel": remove_keys_,
79105
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -85,6 +111,10 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
85111
"timestep_scale_multiplier": remove_keys_,
86112
}
87113

114+
VAE_095_SPECIAL_KEYS_REMAP = {
115+
"timestep_scale_multiplier": remove_keys_,
116+
}
117+
88118

89119
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
90120
state_dict = saved_dict
@@ -161,12 +191,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
161191
"out_channels": 3,
162192
"latent_channels": 128,
163193
"block_out_channels": (128, 256, 512, 512),
194+
"down_block_types": (
195+
"LTXVideoDownBlock3D",
196+
"LTXVideoDownBlock3D",
197+
"LTXVideoDownBlock3D",
198+
"LTXVideoDownBlock3D",
199+
),
164200
"decoder_block_out_channels": (128, 256, 512, 512),
165201
"layers_per_block": (4, 3, 3, 3, 4),
166202
"decoder_layers_per_block": (4, 3, 3, 3, 4),
167203
"spatio_temporal_scaling": (True, True, True, False),
168204
"decoder_spatio_temporal_scaling": (True, True, True, False),
169205
"decoder_inject_noise": (False, False, False, False, False),
206+
"downsample_type": ("conv", "conv", "conv", "conv"),
170207
"upsample_residual": (False, False, False, False),
171208
"upsample_factor": (1, 1, 1, 1),
172209
"patch_size": 4,
@@ -183,12 +220,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
183220
"out_channels": 3,
184221
"latent_channels": 128,
185222
"block_out_channels": (128, 256, 512, 512),
223+
"down_block_types": (
224+
"LTXVideoDownBlock3D",
225+
"LTXVideoDownBlock3D",
226+
"LTXVideoDownBlock3D",
227+
"LTXVideoDownBlock3D",
228+
),
186229
"decoder_block_out_channels": (256, 512, 1024),
187230
"layers_per_block": (4, 3, 3, 3, 4),
188231
"decoder_layers_per_block": (5, 6, 7, 8),
189232
"spatio_temporal_scaling": (True, True, True, False),
190233
"decoder_spatio_temporal_scaling": (True, True, True),
191234
"decoder_inject_noise": (True, True, True, False),
235+
"downsample_type": ("conv", "conv", "conv", "conv"),
192236
"upsample_residual": (True, True, True),
193237
"upsample_factor": (2, 2, 2),
194238
"timestep_conditioning": True,
@@ -201,6 +245,37 @@ def get_vae_config(version: str) -> Dict[str, Any]:
201245
}
202246
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
203247
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
248+
elif version == "0.9.5":
249+
config = {
250+
"in_channels": 3,
251+
"out_channels": 3,
252+
"latent_channels": 128,
253+
"block_out_channels": (128, 256, 512, 1024, 2048),
254+
"down_block_types": (
255+
"LTXVideo095DownBlock3D",
256+
"LTXVideo095DownBlock3D",
257+
"LTXVideo095DownBlock3D",
258+
"LTXVideo095DownBlock3D",
259+
),
260+
"decoder_block_out_channels": (256, 512, 1024),
261+
"layers_per_block": (4, 6, 6, 2, 2),
262+
"decoder_layers_per_block": (5, 5, 5, 5),
263+
"spatio_temporal_scaling": (True, True, True, True),
264+
"decoder_spatio_temporal_scaling": (True, True, True),
265+
"decoder_inject_noise": (False, False, False, False),
266+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
267+
"upsample_residual": (True, True, True),
268+
"upsample_factor": (2, 2, 2),
269+
"timestep_conditioning": True,
270+
"patch_size": 4,
271+
"patch_size_t": 1,
272+
"resnet_norm_eps": 1e-6,
273+
"scaling_factor": 1.0,
274+
"encoder_causal": True,
275+
"decoder_causal": False,
276+
}
277+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
278+
VAE_SPECIAL_KEYS_REMAP.update(VAE_095_SPECIAL_KEYS_REMAP)
204279
return config
205280

206281

@@ -223,7 +298,7 @@ def get_args():
223298
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
224299
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
225300
parser.add_argument(
226-
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
301+
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
227302
)
228303
return parser.parse_args()
229304

0 commit comments

Comments
 (0)