Skip to content

Commit 58a51aa

Browse files
committed
make style
1 parent 0871dc6 commit 58a51aa

File tree

2 files changed

+27
-59
lines changed

2 files changed

+27
-59
lines changed

scripts/convert_ltx_to_diffusers.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
2-
from typing import Any, Dict
32
from pathlib import Path
3+
from typing import Any, Dict
44

55
import torch
66
from accelerate import init_empty_weights
@@ -133,7 +133,7 @@ def convert_transformer(
133133

134134
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
135135
PREFIX_KEY = "vae."
136-
136+
137137
original_state_dict = get_state_dict(load_file(ckpt_path))
138138
with init_empty_weights():
139139
vae = AutoencoderKLLTXVideo(**config)
@@ -155,54 +155,6 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
155155
vae.load_state_dict(original_state_dict, strict=True, assign=True)
156156
return vae
157157

158-
# OURS_VAE_CONFIG = {
159-
# "_class_name": "CausalVideoAutoencoder",
160-
# "dims": 3,
161-
# "in_channels": 3,
162-
# "out_channels": 3,
163-
# "latent_channels": 128,
164-
# "blocks": [
165-
# ["res_x", 4],
166-
# ["compress_all", 1],
167-
# ["res_x_y", 1],
168-
# ["res_x", 3],
169-
# ["compress_all", 1],
170-
# ["res_x_y", 1],
171-
# ["res_x", 3],
172-
# ["compress_all", 1],
173-
# ["res_x", 3],
174-
# ["res_x", 4],
175-
# ],
176-
# "scaling_factor": 1.0,
177-
# "norm_layer": "pixel_norm",
178-
# "patch_size": 4,
179-
# "latent_log_var": "uniform",
180-
# "use_quant_conv": False,
181-
# "causal_decoder": False,
182-
# }
183-
184-
# {
185-
# "_class_name": "CausalVideoAutoencoder",
186-
# "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128,
187-
# "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]],
188-
189-
# previous decoder
190-
# mid: resx
191-
# resx
192-
# compress_all, resx
193-
# resxy, compress_all, resx
194-
# resxy, compress_all, resx
195-
196-
# "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]],
197-
198-
# current decoder
199-
# mid: resx
200-
# compress_all, resx
201-
# compress_all, resx
202-
# compress_all, resx
203-
204-
# "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true
205-
# }
206158

207159
def get_vae_config(version: str) -> Dict[str, Any]:
208160
if version == "0.9.0":
@@ -272,7 +224,9 @@ def get_args():
272224
parser.add_argument("--save_pipeline", action="store_true")
273225
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
274226
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
275-
parser.add_argument("--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model")
227+
parser.add_argument(
228+
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
229+
)
276230
return parser.parse_args()
277231

278232

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@ def __init__(
137137
self.conv_shortcut = LTXCausalConv3d(
138138
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
139139
)
140-
140+
141141
self.scale1 = None
142142
self.scale2 = None
143143
if inject_noise:
144144
self.scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
145145
self.scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
146-
146+
147147
self.scale_shift_table = None
148148
if timestep_conditioning:
149149
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
@@ -166,7 +166,7 @@ def forward(self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None) ->
166166

167167
if self.scale_shift_table is not None:
168168
hidden_states = hidden_states * (1 + scale_1) + shift_1
169-
169+
170170
hidden_states = self.nonlinearity(hidden_states)
171171
hidden_states = self.dropout(hidden_states)
172172
hidden_states = self.conv2(hidden_states)
@@ -211,7 +211,6 @@ def __init__(
211211
is_causal=is_causal,
212212
)
213213

214-
215214
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
216215
batch_size, num_channels, num_frames, height, width = hidden_states.shape
217216

@@ -495,7 +494,17 @@ def __init__(
495494

496495
self.upsamplers = None
497496
if spatio_temporal_scale:
498-
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels * upscale_factor, stride=(2, 2, 2), is_causal=is_causal, residual=upsample_residual, upscale_factor=upscale_factor)])
497+
self.upsamplers = nn.ModuleList(
498+
[
499+
LTXUpsampler3d(
500+
out_channels * upscale_factor,
501+
stride=(2, 2, 2),
502+
is_causal=is_causal,
503+
residual=upsample_residual,
504+
upscale_factor=upscale_factor,
505+
)
506+
]
507+
)
499508

500509
resnets = []
501510
for _ in range(num_layers):
@@ -508,7 +517,7 @@ def __init__(
508517
non_linearity=resnet_act_fn,
509518
is_causal=is_causal,
510519
inject_noise=inject_noise,
511-
timestep_conditioning=timestep_conditioning
520+
timestep_conditioning=timestep_conditioning,
512521
)
513522
)
514523
self.resnets = nn.ModuleList(resnets)
@@ -518,7 +527,7 @@ def __init__(
518527
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
519528
if self.conv_in is not None:
520529
hidden_states = self.conv_in(hidden_states)
521-
530+
522531
if self.time_embedder is not None:
523532
temb = self.time_embedder(
524533
timestep=temb.flatten(),
@@ -744,7 +753,12 @@ def __init__(
744753
)
745754

746755
self.mid_block = LTXMidBlock3d(
747-
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning
756+
in_channels=output_channel,
757+
num_layers=layers_per_block[0],
758+
resnet_eps=resnet_norm_eps,
759+
is_causal=is_causal,
760+
inject_noise=inject_noise[0],
761+
timestep_conditioning=timestep_conditioning,
748762
)
749763

750764
# up blocks

0 commit comments

Comments
 (0)