Skip to content

Commit 9d27df8

Browse files
a-r-r-o-wDN6
andauthored
Rename LTX blocks and docs title (#10213)
* rename blocks and docs * fix docs --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent 055d955 commit 9d27df8

File tree

5 files changed

+49
-48
lines changed

5 files changed

+49
-48
lines changed

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@
429429
- local: api/pipelines/ledits_pp
430430
title: LEDITS++
431431
- local: api/pipelines/ltx_video
432-
title: LTX
432+
title: LTXVideo
433433
- local: api/pipelines/lumina
434434
title: Lumina-T2X
435435
- local: api/pipelines/marigold

docs/source/en/api/models/autoencoderkl_ltx_video.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
1818
```python
1919
from diffusers import AutoencoderKLLTXVideo
2020

21-
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
21+
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
2222
```
2323

2424
## AutoencoderKLLTXVideo

docs/source/en/api/models/ltx_video_transformer3d.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
1818
```python
1919
from diffusers import LTXVideoTransformer3DModel
2020

21-
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
21+
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
2222
```
2323

2424
## LTXVideoTransformer3DModel

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .vae import DecoderOutput, DiagonalGaussianDistribution
2929

3030

31-
class LTXCausalConv3d(nn.Module):
31+
class LTXVideoCausalConv3d(nn.Module):
3232
def __init__(
3333
self,
3434
in_channels: int,
@@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
7979
return hidden_states
8080

8181

82-
class LTXResnetBlock3d(nn.Module):
82+
class LTXVideoResnetBlock3d(nn.Module):
8383
r"""
84-
A 3D ResNet block used in the LTX model.
84+
A 3D ResNet block used in the LTXVideo model.
8585
8686
Args:
8787
in_channels (`int`):
@@ -117,21 +117,21 @@ def __init__(
117117
self.nonlinearity = get_activation(non_linearity)
118118

119119
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
120-
self.conv1 = LTXCausalConv3d(
120+
self.conv1 = LTXVideoCausalConv3d(
121121
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
122122
)
123123

124124
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
125125
self.dropout = nn.Dropout(dropout)
126-
self.conv2 = LTXCausalConv3d(
126+
self.conv2 = LTXVideoCausalConv3d(
127127
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
128128
)
129129

130130
self.norm3 = None
131131
self.conv_shortcut = None
132132
if in_channels != out_channels:
133133
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
134-
self.conv_shortcut = LTXCausalConv3d(
134+
self.conv_shortcut = LTXVideoCausalConv3d(
135135
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
136136
)
137137

@@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
157157
return hidden_states
158158

159159

160-
class LTXUpsampler3d(nn.Module):
160+
class LTXVideoUpsampler3d(nn.Module):
161161
def __init__(
162162
self,
163163
in_channels: int,
@@ -170,7 +170,7 @@ def __init__(
170170

171171
out_channels = in_channels * stride[0] * stride[1] * stride[2]
172172

173-
self.conv = LTXCausalConv3d(
173+
self.conv = LTXVideoCausalConv3d(
174174
in_channels=in_channels,
175175
out_channels=out_channels,
176176
kernel_size=3,
@@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
191191
return hidden_states
192192

193193

194-
class LTXDownBlock3D(nn.Module):
194+
class LTXVideoDownBlock3D(nn.Module):
195195
r"""
196-
Down block used in the LTX model.
196+
Down block used in the LTXVideo model.
197197
198198
Args:
199199
in_channels (`int`):
@@ -235,7 +235,7 @@ def __init__(
235235
resnets = []
236236
for _ in range(num_layers):
237237
resnets.append(
238-
LTXResnetBlock3d(
238+
LTXVideoResnetBlock3d(
239239
in_channels=in_channels,
240240
out_channels=in_channels,
241241
dropout=dropout,
@@ -250,7 +250,7 @@ def __init__(
250250
if spatio_temporal_scale:
251251
self.downsamplers = nn.ModuleList(
252252
[
253-
LTXCausalConv3d(
253+
LTXVideoCausalConv3d(
254254
in_channels=in_channels,
255255
out_channels=in_channels,
256256
kernel_size=3,
@@ -262,7 +262,7 @@ def __init__(
262262

263263
self.conv_out = None
264264
if in_channels != out_channels:
265-
self.conv_out = LTXResnetBlock3d(
265+
self.conv_out = LTXVideoResnetBlock3d(
266266
in_channels=in_channels,
267267
out_channels=out_channels,
268268
dropout=dropout,
@@ -300,9 +300,9 @@ def create_forward(*inputs):
300300

301301

302302
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
303-
class LTXMidBlock3d(nn.Module):
303+
class LTXVideoMidBlock3d(nn.Module):
304304
r"""
305-
A middle block used in the LTX model.
305+
A middle block used in the LTXVideo model.
306306
307307
Args:
308308
in_channels (`int`):
@@ -335,7 +335,7 @@ def __init__(
335335
resnets = []
336336
for _ in range(num_layers):
337337
resnets.append(
338-
LTXResnetBlock3d(
338+
LTXVideoResnetBlock3d(
339339
in_channels=in_channels,
340340
out_channels=in_channels,
341341
dropout=dropout,
@@ -367,9 +367,9 @@ def create_forward(*inputs):
367367
return hidden_states
368368

369369

370-
class LTXUpBlock3d(nn.Module):
370+
class LTXVideoUpBlock3d(nn.Module):
371371
r"""
372-
Up block used in the LTX model.
372+
Up block used in the LTXVideo model.
373373
374374
Args:
375375
in_channels (`int`):
@@ -410,7 +410,7 @@ def __init__(
410410

411411
self.conv_in = None
412412
if in_channels != out_channels:
413-
self.conv_in = LTXResnetBlock3d(
413+
self.conv_in = LTXVideoResnetBlock3d(
414414
in_channels=in_channels,
415415
out_channels=out_channels,
416416
dropout=dropout,
@@ -421,12 +421,12 @@ def __init__(
421421

422422
self.upsamplers = None
423423
if spatio_temporal_scale:
424-
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
424+
self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
425425

426426
resnets = []
427427
for _ in range(num_layers):
428428
resnets.append(
429-
LTXResnetBlock3d(
429+
LTXVideoResnetBlock3d(
430430
in_channels=out_channels,
431431
out_channels=out_channels,
432432
dropout=dropout,
@@ -463,9 +463,9 @@ def create_forward(*inputs):
463463
return hidden_states
464464

465465

466-
class LTXEncoder3d(nn.Module):
466+
class LTXVideoEncoder3d(nn.Module):
467467
r"""
468-
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
468+
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
469469
representation.
470470
471471
Args:
@@ -509,7 +509,7 @@ def __init__(
509509

510510
output_channel = block_out_channels[0]
511511

512-
self.conv_in = LTXCausalConv3d(
512+
self.conv_in = LTXVideoCausalConv3d(
513513
in_channels=self.in_channels,
514514
out_channels=output_channel,
515515
kernel_size=3,
@@ -524,7 +524,7 @@ def __init__(
524524
input_channel = output_channel
525525
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
526526

527-
down_block = LTXDownBlock3D(
527+
down_block = LTXVideoDownBlock3D(
528528
in_channels=input_channel,
529529
out_channels=output_channel,
530530
num_layers=layers_per_block[i],
@@ -536,7 +536,7 @@ def __init__(
536536
self.down_blocks.append(down_block)
537537

538538
# mid block
539-
self.mid_block = LTXMidBlock3d(
539+
self.mid_block = LTXVideoMidBlock3d(
540540
in_channels=output_channel,
541541
num_layers=layers_per_block[-1],
542542
resnet_eps=resnet_norm_eps,
@@ -546,14 +546,14 @@ def __init__(
546546
# out
547547
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
548548
self.conv_act = nn.SiLU()
549-
self.conv_out = LTXCausalConv3d(
549+
self.conv_out = LTXVideoCausalConv3d(
550550
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
551551
)
552552

553553
self.gradient_checkpointing = False
554554

555555
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
556-
r"""The forward method of the `LTXEncoder3D` class."""
556+
r"""The forward method of the `LTXVideoEncoder3d` class."""
557557

558558
p = self.patch_size
559559
p_t = self.patch_size_t
@@ -599,9 +599,10 @@ def create_forward(*inputs):
599599
return hidden_states
600600

601601

602-
class LTXDecoder3d(nn.Module):
602+
class LTXVideoDecoder3d(nn.Module):
603603
r"""
604-
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
604+
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
605+
sample.
605606
606607
Args:
607608
in_channels (`int`, defaults to 128):
@@ -647,11 +648,11 @@ def __init__(
647648
layers_per_block = tuple(reversed(layers_per_block))
648649
output_channel = block_out_channels[0]
649650

650-
self.conv_in = LTXCausalConv3d(
651+
self.conv_in = LTXVideoCausalConv3d(
651652
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
652653
)
653654

654-
self.mid_block = LTXMidBlock3d(
655+
self.mid_block = LTXVideoMidBlock3d(
655656
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
656657
)
657658

@@ -662,7 +663,7 @@ def __init__(
662663
input_channel = output_channel
663664
output_channel = block_out_channels[i]
664665

665-
up_block = LTXUpBlock3d(
666+
up_block = LTXVideoUpBlock3d(
666667
in_channels=input_channel,
667668
out_channels=output_channel,
668669
num_layers=layers_per_block[i + 1],
@@ -676,7 +677,7 @@ def __init__(
676677
# out
677678
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
678679
self.conv_act = nn.SiLU()
679-
self.conv_out = LTXCausalConv3d(
680+
self.conv_out = LTXVideoCausalConv3d(
680681
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
681682
)
682683

@@ -777,7 +778,7 @@ def __init__(
777778
) -> None:
778779
super().__init__()
779780

780-
self.encoder = LTXEncoder3d(
781+
self.encoder = LTXVideoEncoder3d(
781782
in_channels=in_channels,
782783
out_channels=latent_channels,
783784
block_out_channels=block_out_channels,
@@ -788,7 +789,7 @@ def __init__(
788789
resnet_norm_eps=resnet_norm_eps,
789790
is_causal=encoder_causal,
790791
)
791-
self.decoder = LTXDecoder3d(
792+
self.decoder = LTXVideoDecoder3d(
792793
in_channels=latent_channels,
793794
out_channels=out_channels,
794795
block_out_channels=block_out_channels,
@@ -837,7 +838,7 @@ def __init__(
837838
self.tile_sample_stride_width = 448
838839

839840
def _set_gradient_checkpointing(self, module, value=False):
840-
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
841+
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
841842
module.gradient_checkpointing = value
842843

843844
def enable_tiling(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3636

3737

38-
class LTXAttentionProcessor2_0:
38+
class LTXVideoAttentionProcessor2_0:
3939
r"""
4040
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
4141
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
4444
def __init__(self):
4545
if not hasattr(F, "scaled_dot_product_attention"):
4646
raise ImportError(
47-
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
47+
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
4848
)
4949

5050
def __call__(
@@ -92,7 +92,7 @@ def __call__(
9292
return hidden_states
9393

9494

95-
class LTXRotaryPosEmbed(nn.Module):
95+
class LTXVideoRotaryPosEmbed(nn.Module):
9696
def __init__(
9797
self,
9898
dim: int,
@@ -164,7 +164,7 @@ def forward(
164164

165165

166166
@maybe_allow_in_graph
167-
class LTXTransformerBlock(nn.Module):
167+
class LTXVideoTransformerBlock(nn.Module):
168168
r"""
169169
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
170170
@@ -208,7 +208,7 @@ def __init__(
208208
cross_attention_dim=None,
209209
out_bias=attention_out_bias,
210210
qk_norm=qk_norm,
211-
processor=LTXAttentionProcessor2_0(),
211+
processor=LTXVideoAttentionProcessor2_0(),
212212
)
213213

214214
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -221,7 +221,7 @@ def __init__(
221221
bias=attention_bias,
222222
out_bias=attention_out_bias,
223223
qk_norm=qk_norm,
224-
processor=LTXAttentionProcessor2_0(),
224+
processor=LTXVideoAttentionProcessor2_0(),
225225
)
226226

227227
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -327,7 +327,7 @@ def __init__(
327327

328328
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
329329

330-
self.rope = LTXRotaryPosEmbed(
330+
self.rope = LTXVideoRotaryPosEmbed(
331331
dim=inner_dim,
332332
base_num_frames=20,
333333
base_height=2048,
@@ -339,7 +339,7 @@ def __init__(
339339

340340
self.transformer_blocks = nn.ModuleList(
341341
[
342-
LTXTransformerBlock(
342+
LTXVideoTransformerBlock(
343343
dim=inner_dim,
344344
num_attention_heads=num_attention_heads,
345345
attention_head_dim=attention_head_dim,

0 commit comments

Comments
 (0)