Skip to content

Commit 4325449

Browse files
authored
Merge branch 'main' into ltxv-0.9.1-integration
2 parents 7b412c5 + 9d27df8 commit 4325449

File tree

10 files changed

+184
-167
lines changed

10 files changed

+184
-167
lines changed

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@
429429
- local: api/pipelines/ledits_pp
430430
title: LEDITS++
431431
- local: api/pipelines/ltx_video
432-
title: LTX
432+
title: LTXVideo
433433
- local: api/pipelines/lumina
434434
title: Lumina-T2X
435435
- local: api/pipelines/marigold

docs/source/en/api/models/autoencoderkl_ltx_video.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
1818
```python
1919
from diffusers import AutoencoderKLLTXVideo
2020

21-
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
21+
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
2222
```
2323

2424
## AutoencoderKLLTXVideo

docs/source/en/api/models/ltx_video_transformer3d.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
1818
```python
1919
from diffusers import LTXVideoTransformer3DModel
2020

21-
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
21+
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
2222
```
2323

2424
## LTXVideoTransformer3DModel

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,5 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
177177

178178
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
179179
self.config.encoder_hid_dim_type = "ip_image_proj"
180+
181+
self.to(dtype=self.dtype, device=self.device)

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .vae import DecoderOutput, DiagonalGaussianDistribution
3030

3131

32-
class LTXCausalConv3d(nn.Module):
32+
class LTXVideoCausalConv3d(nn.Module):
3333
def __init__(
3434
self,
3535
in_channels: int,
@@ -80,9 +80,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
8080
return hidden_states
8181

8282

83-
class LTXResnetBlock3d(nn.Module):
83+
class LTXVideoResnetBlock3d(nn.Module):
8484
r"""
85-
A 3D ResNet block used in the LTX model.
85+
A 3D ResNet block used in the LTXVideo model.
8686
8787
Args:
8888
in_channels (`int`):
@@ -120,21 +120,21 @@ def __init__(
120120
self.nonlinearity = get_activation(non_linearity)
121121

122122
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
123-
self.conv1 = LTXCausalConv3d(
123+
self.conv1 = LTXVideoCausalConv3d(
124124
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
125125
)
126126

127127
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
128128
self.dropout = nn.Dropout(dropout)
129-
self.conv2 = LTXCausalConv3d(
129+
self.conv2 = LTXVideoCausalConv3d(
130130
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
131131
)
132132

133133
self.norm3 = None
134134
self.conv_shortcut = None
135135
if in_channels != out_channels:
136136
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
137-
self.conv_shortcut = LTXCausalConv3d(
137+
self.conv_shortcut = LTXVideoCausalConv3d(
138138
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
139139
)
140140

@@ -196,7 +196,7 @@ def forward(
196196
return hidden_states
197197

198198

199-
class LTXUpsampler3d(nn.Module):
199+
class LTXVideoUpsampler3d(nn.Module):
200200
def __init__(
201201
self,
202202
in_channels: int,
@@ -213,7 +213,7 @@ def __init__(
213213

214214
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
215215

216-
self.conv = LTXCausalConv3d(
216+
self.conv = LTXVideoCausalConv3d(
217217
in_channels=in_channels,
218218
out_channels=out_channels,
219219
kernel_size=3,
@@ -246,9 +246,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
246246
return hidden_states
247247

248248

249-
class LTXDownBlock3D(nn.Module):
249+
class LTXVideoDownBlock3D(nn.Module):
250250
r"""
251-
Down block used in the LTX model.
251+
Down block used in the LTXVideo model.
252252
253253
Args:
254254
in_channels (`int`):
@@ -290,7 +290,7 @@ def __init__(
290290
resnets = []
291291
for _ in range(num_layers):
292292
resnets.append(
293-
LTXResnetBlock3d(
293+
LTXVideoResnetBlock3d(
294294
in_channels=in_channels,
295295
out_channels=in_channels,
296296
dropout=dropout,
@@ -305,7 +305,7 @@ def __init__(
305305
if spatio_temporal_scale:
306306
self.downsamplers = nn.ModuleList(
307307
[
308-
LTXCausalConv3d(
308+
LTXVideoCausalConv3d(
309309
in_channels=in_channels,
310310
out_channels=in_channels,
311311
kernel_size=3,
@@ -317,7 +317,7 @@ def __init__(
317317

318318
self.conv_out = None
319319
if in_channels != out_channels:
320-
self.conv_out = LTXResnetBlock3d(
320+
self.conv_out = LTXVideoResnetBlock3d(
321321
in_channels=in_channels,
322322
out_channels=out_channels,
323323
dropout=dropout,
@@ -362,9 +362,9 @@ def create_forward(*inputs):
362362

363363

364364
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
365-
class LTXMidBlock3d(nn.Module):
365+
class LTXVideoMidBlock3d(nn.Module):
366366
r"""
367-
A middle block used in the LTX model.
367+
A middle block used in the LTXVideo model.
368368
369369
Args:
370370
in_channels (`int`):
@@ -403,7 +403,7 @@ def __init__(
403403
resnets = []
404404
for _ in range(num_layers):
405405
resnets.append(
406-
LTXResnetBlock3d(
406+
LTXVideoResnetBlock3d(
407407
in_channels=in_channels,
408408
out_channels=in_channels,
409409
dropout=dropout,
@@ -454,9 +454,9 @@ def create_forward(*inputs):
454454
return hidden_states
455455

456456

457-
class LTXUpBlock3d(nn.Module):
457+
class LTXVideoUpBlock3d(nn.Module):
458458
r"""
459-
Up block used in the LTX model.
459+
Up block used in the LTXVideo model.
460460
461461
Args:
462462
in_channels (`int`):
@@ -505,7 +505,7 @@ def __init__(
505505

506506
self.conv_in = None
507507
if in_channels != out_channels:
508-
self.conv_in = LTXResnetBlock3d(
508+
self.conv_in = LTXVideoResnetBlock3d(
509509
in_channels=in_channels,
510510
out_channels=out_channels,
511511
dropout=dropout,
@@ -520,7 +520,7 @@ def __init__(
520520
if spatio_temporal_scale:
521521
self.upsamplers = nn.ModuleList(
522522
[
523-
LTXUpsampler3d(
523+
LTXVideoUpsampler3d(
524524
out_channels * upscale_factor,
525525
stride=(2, 2, 2),
526526
is_causal=is_causal,
@@ -533,7 +533,7 @@ def __init__(
533533
resnets = []
534534
for _ in range(num_layers):
535535
resnets.append(
536-
LTXResnetBlock3d(
536+
LTXVideoResnetBlock3d(
537537
in_channels=out_channels,
538538
out_channels=out_channels,
539539
dropout=dropout,
@@ -589,9 +589,9 @@ def create_forward(*inputs):
589589
return hidden_states
590590

591591

592-
class LTXEncoder3d(nn.Module):
592+
class LTXVideoEncoder3d(nn.Module):
593593
r"""
594-
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
594+
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
595595
representation.
596596
597597
Args:
@@ -635,7 +635,7 @@ def __init__(
635635

636636
output_channel = block_out_channels[0]
637637

638-
self.conv_in = LTXCausalConv3d(
638+
self.conv_in = LTXVideoCausalConv3d(
639639
in_channels=self.in_channels,
640640
out_channels=output_channel,
641641
kernel_size=3,
@@ -650,7 +650,7 @@ def __init__(
650650
input_channel = output_channel
651651
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
652652

653-
down_block = LTXDownBlock3D(
653+
down_block = LTXVideoDownBlock3D(
654654
in_channels=input_channel,
655655
out_channels=output_channel,
656656
num_layers=layers_per_block[i],
@@ -662,7 +662,7 @@ def __init__(
662662
self.down_blocks.append(down_block)
663663

664664
# mid block
665-
self.mid_block = LTXMidBlock3d(
665+
self.mid_block = LTXVideoMidBlock3d(
666666
in_channels=output_channel,
667667
num_layers=layers_per_block[-1],
668668
resnet_eps=resnet_norm_eps,
@@ -672,14 +672,14 @@ def __init__(
672672
# out
673673
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
674674
self.conv_act = nn.SiLU()
675-
self.conv_out = LTXCausalConv3d(
675+
self.conv_out = LTXVideoCausalConv3d(
676676
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
677677
)
678678

679679
self.gradient_checkpointing = False
680680

681681
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
682-
r"""The forward method of the `LTXEncoder3D` class."""
682+
r"""The forward method of the `LTXVideoEncoder3d` class."""
683683

684684
p = self.patch_size
685685
p_t = self.patch_size_t
@@ -725,9 +725,10 @@ def create_forward(*inputs):
725725
return hidden_states
726726

727727

728-
class LTXDecoder3d(nn.Module):
728+
class LTXVideoDecoder3d(nn.Module):
729729
r"""
730-
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
730+
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
731+
sample.
731732
732733
Args:
733734
in_channels (`int`, defaults to 128):
@@ -782,11 +783,11 @@ def __init__(
782783
upsample_factor = tuple(reversed(upsample_factor))
783784
output_channel = block_out_channels[0]
784785

785-
self.conv_in = LTXCausalConv3d(
786+
self.conv_in = LTXVideoCausalConv3d(
786787
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
787788
)
788789

789-
self.mid_block = LTXMidBlock3d(
790+
self.mid_block = LTXVideoMidBlock3d(
790791
in_channels=output_channel,
791792
num_layers=layers_per_block[0],
792793
resnet_eps=resnet_norm_eps,
@@ -802,7 +803,7 @@ def __init__(
802803
input_channel = output_channel // upsample_factor[i]
803804
output_channel = block_out_channels[i] // upsample_factor[i]
804805

805-
up_block = LTXUpBlock3d(
806+
up_block = LTXVideoUpBlock3d(
806807
in_channels=input_channel,
807808
out_channels=output_channel,
808809
num_layers=layers_per_block[i + 1],
@@ -820,7 +821,7 @@ def __init__(
820821
# out
821822
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
822823
self.conv_act = nn.SiLU()
823-
self.conv_out = LTXCausalConv3d(
824+
self.conv_out = LTXVideoCausalConv3d(
824825
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
825826
)
826827

@@ -951,7 +952,7 @@ def __init__(
951952
) -> None:
952953
super().__init__()
953954

954-
self.encoder = LTXEncoder3d(
955+
self.encoder = LTXVideoEncoder3d(
955956
in_channels=in_channels,
956957
out_channels=latent_channels,
957958
block_out_channels=block_out_channels,
@@ -962,7 +963,7 @@ def __init__(
962963
resnet_norm_eps=resnet_norm_eps,
963964
is_causal=encoder_causal,
964965
)
965-
self.decoder = LTXDecoder3d(
966+
self.decoder = LTXVideoDecoder3d(
966967
in_channels=latent_channels,
967968
out_channels=out_channels,
968969
block_out_channels=decoder_block_out_channels,
@@ -1015,7 +1016,7 @@ def __init__(
10151016
self.tile_sample_stride_width = 448
10161017

10171018
def _set_gradient_checkpointing(self, module, value=False):
1018-
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
1019+
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
10191020
module.gradient_checkpointing = value
10201021

10211022
def enable_tiling(

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
748748
pos_embedding = self._get_positional_embeddings(
749749
height, width, pre_time_compression_frames, device=embeds.device
750750
)
751-
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
752751
else:
753752
pos_embedding = self.pos_embedding
754753

754+
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
755755
embeds = embeds + pos_embedding
756756

757757
return embeds

0 commit comments

Comments
 (0)