Commit 5316f4b

Commit message: update
1 parent: 58a51aa

File tree: 5 files changed, +294 -48 lines

scripts/convert_ltx_to_diffusers.py (2 additions, 4 deletions)
@@ -70,8 +70,6 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "up_blocks.7": "up_blocks.3.upsamplers.0",
     "up_blocks.8": "up_blocks.3",
     # common
-    "per_channel_scale1": "scale1",
-    "per_channel_scale2": "scale2",
     "last_time_embedder": "time_embedder",
     "last_scale_shift_table": "scale_shift_table",
 }
@@ -168,7 +166,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "decoder_layers_per_block": (4, 3, 3, 3, 4),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True, False),
-            "decoder_inject_noise": (False, False, False, False),
+            "decoder_inject_noise": (False, False, False, False, False),
             "upsample_residual": (False, False, False, False),
             "upsample_factor": (1, 1, 1, 1),
             "patch_size": 4,
@@ -190,7 +188,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
             "decoder_layers_per_block": (5, 6, 7, 8),
             "spatio_temporal_scaling": (True, True, True, False),
             "decoder_spatio_temporal_scaling": (True, True, True),
-            "decoder_inject_noise": (False, True, True, True),
+            "decoder_inject_noise": (True, True, True, False),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
             "timestep_conditioning": True,

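The dropped `per_channel_scale*` entries reflect the attribute rename in the autoencoder below: the module now uses the checkpoint's own parameter names, so no remapping is needed. For context, a minimal sketch of how a rename map like this is typically applied to checkpoint keys; the helper below is illustrative, not this script's actual code:

from typing import Any, Dict

# Hypothetical excerpt of the rename map after this commit's change.
VAE_RENAME_MAP = {
    "up_blocks.8": "up_blocks.3",
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

def rename_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Replace every mapped substring in each checkpoint key.
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in VAE_RENAME_MAP.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed
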
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py (92 additions, 42 deletions)
@@ -138,43 +138,53 @@ def __init__(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
         )
 
-        self.scale1 = None
-        self.scale2 = None
+        self.per_channel_scale1 = None
+        self.per_channel_scale2 = None
         if inject_noise:
-            self.scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
-            self.scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+            self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+            self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
 
         self.scale_shift_table = None
         if timestep_conditioning:
             self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
 
-    def forward(self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
+    ) -> torch.Tensor:
         hidden_states = inputs
 
         hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
-        scale_1, shift_1, scale_2, shift_2 = self.scale_shift_table.unbind(dim=0)
+
+        if self.scale_shift_table is not None:
+            temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
+            shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
+            hidden_states = hidden_states * (1 + scale_1) + shift_1
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.conv1(hidden_states)
 
-        if self.scale1 is not None:
+        if self.per_channel_scale1 is not None:
             spatial_shape = hidden_states.shape[-2:]
-            spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype)
-            hidden_states = hidden_states + (spatial_noise * self.scale1)[None, :, None, :, :]
+            spatial_noise = torch.randn(
+                spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+            )
+            hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, :, :]
 
         hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
 
         if self.scale_shift_table is not None:
-            hidden_states = hidden_states * (1 + scale_1) + shift_1
+            hidden_states = hidden_states * (1 + scale_2) + shift_2
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.conv2(hidden_states)
 
-        if self.scale2 is not None:
+        if self.per_channel_scale2 is not None:
             spatial_shape = hidden_states.shape[-2:]
-            spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype)
-            hidden_states = hidden_states + (spatial_noise * self.scale2)[None, :, None, :, :]
+            spatial_noise = torch.randn(
+                spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+            )
+            hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, :, :]
 
         if self.norm3 is not None:
             inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
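The forward signature now threads an optional torch.Generator into both torch.randn calls, so the injected spatial noise can be made deterministic. A standalone sketch of the effect, not part of the commit:

import torch

gen = torch.Generator().manual_seed(42)
noise_a = torch.randn((16, 16), generator=gen)

gen = torch.Generator().manual_seed(42)
noise_b = torch.randn((16, 16), generator=gen)

assert torch.equal(noise_a, noise_b)  # same seed, identical injected noise
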
@@ -318,7 +328,12 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
         r"""Forward method of the `LTXDownBlock3D` class."""
 
         for i, resnet in enumerate(self.resnets):
@@ -330,16 +345,18 @@ def create_forward(*inputs):
 
                 return create_forward
 
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(resnet), hidden_states, temb, generator
+            )
         else:
-            hidden_states = resnet(hidden_states)
+            hidden_states = resnet(hidden_states, temb, generator)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
         if self.conv_out is not None:
-            hidden_states = self.conv_out(hidden_states)
+            hidden_states = self.conv_out(hidden_states, temb, generator)
 
         return hidden_states
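The same wrapper pattern recurs in the blocks below: `create_custom_forward` returns a thin closure, and every runtime input (`hidden_states`, `temb`, `generator`) is passed positionally through `torch.utils.checkpoint.checkpoint` so the checkpointed and plain paths stay in sync. A self-contained sketch of the mechanism, with a toy module standing in for the resnet:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x, temb=None, generator=None):
        return self.proj(x)

def create_custom_forward(module):
    def create_forward(*inputs):
        return module(*inputs)

    return create_forward

block = Block()
x = torch.randn(2, 8, requires_grad=True)
# Non-tensor extras (None here) pass through checkpoint unchanged.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), x, None, None, use_reentrant=False
)
out.sum().backward()
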

@@ -401,7 +418,12 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
         r"""Forward method of the `LTXMidBlock3D` class."""
 
         if self.time_embedder is not None:
@@ -423,9 +445,11 @@ def create_forward(*inputs):
 
                 return create_forward
 
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(resnet), hidden_states, temb, generator
+            )
         else:
-            hidden_states = resnet(hidden_states, temb)
+            hidden_states = resnet(hidden_states, temb, generator)
 
         return hidden_states

@@ -524,9 +548,14 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
         if self.conv_in is not None:
-            hidden_states = self.conv_in(hidden_states)
+            hidden_states = self.conv_in(hidden_states, temb, generator)
 
         if self.time_embedder is not None:
             temb = self.time_embedder(
@@ -551,9 +580,11 @@ def create_forward(*inputs):
 
                 return create_forward
 
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(resnet), hidden_states, temb, generator
+            )
         else:
-            hidden_states = resnet(hidden_states)
+            hidden_states = resnet(hidden_states, temb, generator)
 
         return hidden_states

@@ -746,6 +777,9 @@ def __init__(
         block_out_channels = tuple(reversed(block_out_channels))
         spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
         layers_per_block = tuple(reversed(layers_per_block))
+        inject_noise = tuple(reversed(inject_noise))
+        upsample_residual = tuple(reversed(upsample_residual))
+        upsample_factor = tuple(reversed(upsample_factor))
         output_channel = block_out_channels[0]
 
         self.conv_in = LTXCausalConv3d(
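The three new reversed() calls bring inject_noise, upsample_residual, and upsample_factor in line with the other per-block options, which are specified in encoder order and flipped for the decoder. That is also why the convert script above reorders decoder_inject_noise from (False, True, True, True) to (True, True, True, False): after the reversal the effective per-block values are unchanged, as this one-liner shows:

decoder_inject_noise = (True, True, True, False)  # new config order (encoder-side)
print(tuple(reversed(decoder_inject_noise)))      # (False, True, True, True): the order the decoder consumes
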
@@ -810,29 +844,31 @@ def create_forward(*inputs):
 
                 return create_forward
 
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, temb
+            )
 
             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
         else:
-            hidden_states = self.mid_block(hidden_states)
+            hidden_states = self.mid_block(hidden_states, temb)
 
             for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states)
+                hidden_states = up_block(hidden_states, temb)
 
         hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
 
         if self.time_embedder is not None:
-            embedded_timestep = self.time_embedder(
+            temb = self.time_embedder(
                 timestep=temb.flatten(),
                 resolution=None,
                 aspect_ratio=None,
                 batch_size=hidden_states.size(0),
                 hidden_dtype=hidden_states.dtype,
             )
-            embedded_timestep = embedded_timestep.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
-            embedded_timestep = embedded_timestep + self.scale_shift_table[None, :, None, None, None]
-            shift, scale = embedded_timestep.unbind(dim=1)
+            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
+            temb = temb + self.scale_shift_table[None, ..., None, None, None]
+            shift, scale = temb.unbind(dim=1)
             hidden_states = hidden_states * (1 + scale) + shift
 
         hidden_states = self.conv_act(hidden_states)
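A shape-level sketch of the final timestep conditioning above (dimensions chosen arbitrarily for illustration): the embedded timestep is reshaped into a (shift, scale) pair per channel, offset by the learned scale_shift_table, and broadcast over the (batch, channels, frames, height, width) hidden states.

import torch

batch_size, channels = 2, 4
hidden_states = torch.randn(batch_size, channels, 3, 8, 8)  # (B, C, T, H, W)
temb = torch.randn(batch_size, 2 * channels)                # output of the time embedder
scale_shift_table = torch.randn(2, channels)

temb = temb.view(batch_size, -1, 1, 1, 1).unflatten(1, (2, -1))
temb = temb + scale_shift_table[None, ..., None, None, None]
shift, scale = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale) + shift
print(hidden_states.shape)  # torch.Size([2, 4, 3, 8, 8])
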
@@ -902,7 +938,7 @@ def __init__(
         decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
         spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
         decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
-        decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False),
+        decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
         timestep_conditioning: bool = False,
@@ -1078,13 +1114,15 @@ def encode(
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
 
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
-            return self.tiled_decode(z, return_dict=return_dict)
+            return self.tiled_decode(z, temb, return_dict=return_dict)
 
         if self.use_framewise_decoding:
             # TODO(aryan): requires investigation
@@ -1094,15 +1132,17 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
             )
         else:
-            dec = self.decoder(z)
+            dec = self.decoder(z, temb)
 
         if not return_dict:
             return (dec,)
 
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
@@ -1117,10 +1157,15 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 returned.
         """
         if self.use_slicing and z.shape[0] > 1:
-            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            if temb is not None:
+                decoded_slices = [
+                    self._decode(z_slice, t_slice).sample for z_slice, t_slice in zip(z.split(1), temb.split(1))
+                ]
+            else:
+                decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z).sample
+            decoded = self._decode(z, temb).sample
 
         if not return_dict:
             return (decoded,)
@@ -1202,7 +1247,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
 
-    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
 
@@ -1243,7 +1290,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                         "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
                     )
                 else:
-                    time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
+                    time = self.decoder(
+                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
+                    )
 
                 row.append(time)
             rows.append(row)
@@ -1271,6 +1320,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
     def forward(
         self,
         sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
@@ -1281,7 +1331,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z, temb)
         if not return_dict:
             return (dec,)
         return dec
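Putting the autoencoder changes together, a hedged usage sketch of the new decode signature; `vae` and `z` are placeholders for a timestep-conditioned AutoencoderKLLTXVideo instance and a latent batch of shape (B, C, F, H, W), not part of the commit:

import torch

# Placeholder setup: assumes `vae` has config.timestep_conditioning=True
# and `z` is a latent batch of shape (B, C, F, H, W).
temb = torch.tensor([0.05] * z.shape[0], device=z.device, dtype=z.dtype)
video = vae.decode(z, temb, return_dict=False)[0]

# temb defaults to None, so existing callers remain valid:
video = vae.decode(z, return_dict=False)[0]
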

src/diffusers/pipelines/ltx/pipeline_ltx.py (21 additions, 1 deletion)
@@ -511,6 +511,8 @@ def __call__(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        decode_timestep: Union[float, List[float]] = 0.05,
+        decode_noise_scale: Union[float, List[float]] = 0.025,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -753,7 +755,25 @@ def __call__(
                 latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
             )
             latents = latents.to(prompt_embeds.dtype)
-            video = self.vae.decode(latents, return_dict=False)[0]
+
+            if not self.vae.config.timestep_conditioning:
+                timestep = None
+            else:
+                noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+                if not isinstance(decode_timestep, list):
+                    decode_timestep = [decode_timestep] * batch_size
+                if decode_noise_scale is None:
+                    decode_noise_scale = decode_timestep
+                elif not isinstance(decode_noise_scale, list):
+                    decode_noise_scale = [decode_noise_scale] * batch_size
+
+                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+                    :, None, None, None, None
+                ]
+                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+            video = self.vae.decode(latents, timestep, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             # Offload all models
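The added block mixes fresh noise into the denormalized latents before decoding, latents = (1 - s) * latents + s * noise with s = decode_noise_scale (falling back to decode_timestep when unset), and hands the per-sample decode timestep to the VAE. A minimal numeric sketch of that blending, with illustrative shapes:

import torch

batch_size = 2
latents = torch.randn(batch_size, 128, 3, 8, 8)  # (B, C, F, H, W), illustrative
noise = torch.randn_like(latents)

decode_timestep = [0.05] * batch_size
decode_noise_scale = torch.tensor([0.025] * batch_size)[:, None, None, None, None]

latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
timestep = torch.tensor(decode_timestep)  # later passed to vae.decode(latents, timestep, ...)
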
