28
28
from .vae import DecoderOutput , DiagonalGaussianDistribution
29
29
30
30
31
- class LTXCausalConv3d (nn .Module ):
31
+ class LTXVideoCausalConv3d (nn .Module ):
32
32
def __init__ (
33
33
self ,
34
34
in_channels : int ,
@@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
79
79
return hidden_states
80
80
81
81
82
- class LTXResnetBlock3d (nn .Module ):
82
+ class LTXVideoResnetBlock3d (nn .Module ):
83
83
r"""
84
- A 3D ResNet block used in the LTX model.
84
+ A 3D ResNet block used in the LTXVideo model.
85
85
86
86
Args:
87
87
in_channels (`int`):
@@ -117,21 +117,21 @@ def __init__(
117
117
self .nonlinearity = get_activation (non_linearity )
118
118
119
119
self .norm1 = RMSNorm (in_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
120
- self .conv1 = LTXCausalConv3d (
120
+ self .conv1 = LTXVideoCausalConv3d (
121
121
in_channels = in_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
122
122
)
123
123
124
124
self .norm2 = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
125
125
self .dropout = nn .Dropout (dropout )
126
- self .conv2 = LTXCausalConv3d (
126
+ self .conv2 = LTXVideoCausalConv3d (
127
127
in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
128
128
)
129
129
130
130
self .norm3 = None
131
131
self .conv_shortcut = None
132
132
if in_channels != out_channels :
133
133
self .norm3 = nn .LayerNorm (in_channels , eps = eps , elementwise_affine = True , bias = True )
134
- self .conv_shortcut = LTXCausalConv3d (
134
+ self .conv_shortcut = LTXVideoCausalConv3d (
135
135
in_channels = in_channels , out_channels = out_channels , kernel_size = 1 , stride = 1 , is_causal = is_causal
136
136
)
137
137
@@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
157
157
return hidden_states
158
158
159
159
160
- class LTXUpsampler3d (nn .Module ):
160
+ class LTXVideoUpsampler3d (nn .Module ):
161
161
def __init__ (
162
162
self ,
163
163
in_channels : int ,
@@ -170,7 +170,7 @@ def __init__(
170
170
171
171
out_channels = in_channels * stride [0 ] * stride [1 ] * stride [2 ]
172
172
173
- self .conv = LTXCausalConv3d (
173
+ self .conv = LTXVideoCausalConv3d (
174
174
in_channels = in_channels ,
175
175
out_channels = out_channels ,
176
176
kernel_size = 3 ,
@@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
191
191
return hidden_states
192
192
193
193
194
- class LTXDownBlock3D (nn .Module ):
194
+ class LTXVideoDownBlock3D (nn .Module ):
195
195
r"""
196
- Down block used in the LTX model.
196
+ Down block used in the LTXVideo model.
197
197
198
198
Args:
199
199
in_channels (`int`):
@@ -235,7 +235,7 @@ def __init__(
235
235
resnets = []
236
236
for _ in range (num_layers ):
237
237
resnets .append (
238
- LTXResnetBlock3d (
238
+ LTXVideoResnetBlock3d (
239
239
in_channels = in_channels ,
240
240
out_channels = in_channels ,
241
241
dropout = dropout ,
@@ -250,7 +250,7 @@ def __init__(
250
250
if spatio_temporal_scale :
251
251
self .downsamplers = nn .ModuleList (
252
252
[
253
- LTXCausalConv3d (
253
+ LTXVideoCausalConv3d (
254
254
in_channels = in_channels ,
255
255
out_channels = in_channels ,
256
256
kernel_size = 3 ,
@@ -262,7 +262,7 @@ def __init__(
262
262
263
263
self .conv_out = None
264
264
if in_channels != out_channels :
265
- self .conv_out = LTXResnetBlock3d (
265
+ self .conv_out = LTXVideoResnetBlock3d (
266
266
in_channels = in_channels ,
267
267
out_channels = out_channels ,
268
268
dropout = dropout ,
@@ -300,9 +300,9 @@ def create_forward(*inputs):
300
300
301
301
302
302
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
303
- class LTXMidBlock3d (nn .Module ):
303
+ class LTXVideoMidBlock3d (nn .Module ):
304
304
r"""
305
- A middle block used in the LTX model.
305
+ A middle block used in the LTXVideo model.
306
306
307
307
Args:
308
308
in_channels (`int`):
@@ -335,7 +335,7 @@ def __init__(
335
335
resnets = []
336
336
for _ in range (num_layers ):
337
337
resnets .append (
338
- LTXResnetBlock3d (
338
+ LTXVideoResnetBlock3d (
339
339
in_channels = in_channels ,
340
340
out_channels = in_channels ,
341
341
dropout = dropout ,
@@ -367,9 +367,9 @@ def create_forward(*inputs):
367
367
return hidden_states
368
368
369
369
370
- class LTXUpBlock3d (nn .Module ):
370
+ class LTXVideoUpBlock3d (nn .Module ):
371
371
r"""
372
- Up block used in the LTX model.
372
+ Up block used in the LTXVideo model.
373
373
374
374
Args:
375
375
in_channels (`int`):
@@ -410,7 +410,7 @@ def __init__(
410
410
411
411
self .conv_in = None
412
412
if in_channels != out_channels :
413
- self .conv_in = LTXResnetBlock3d (
413
+ self .conv_in = LTXVideoResnetBlock3d (
414
414
in_channels = in_channels ,
415
415
out_channels = out_channels ,
416
416
dropout = dropout ,
@@ -421,12 +421,12 @@ def __init__(
421
421
422
422
self .upsamplers = None
423
423
if spatio_temporal_scale :
424
- self .upsamplers = nn .ModuleList ([LTXUpsampler3d (out_channels , stride = (2 , 2 , 2 ), is_causal = is_causal )])
424
+ self .upsamplers = nn .ModuleList ([LTXVideoUpsampler3d (out_channels , stride = (2 , 2 , 2 ), is_causal = is_causal )])
425
425
426
426
resnets = []
427
427
for _ in range (num_layers ):
428
428
resnets .append (
429
- LTXResnetBlock3d (
429
+ LTXVideoResnetBlock3d (
430
430
in_channels = out_channels ,
431
431
out_channels = out_channels ,
432
432
dropout = dropout ,
@@ -463,9 +463,9 @@ def create_forward(*inputs):
463
463
return hidden_states
464
464
465
465
466
- class LTXEncoder3d (nn .Module ):
466
+ class LTXVideoEncoder3d (nn .Module ):
467
467
r"""
468
- The `LTXEncoder3D ` layer of a variational autoencoder that encodes input video samples to its latent
468
+ The `LTXVideoEncoder3d ` layer of a variational autoencoder that encodes input video samples to its latent
469
469
representation.
470
470
471
471
Args:
@@ -509,7 +509,7 @@ def __init__(
509
509
510
510
output_channel = block_out_channels [0 ]
511
511
512
- self .conv_in = LTXCausalConv3d (
512
+ self .conv_in = LTXVideoCausalConv3d (
513
513
in_channels = self .in_channels ,
514
514
out_channels = output_channel ,
515
515
kernel_size = 3 ,
@@ -524,7 +524,7 @@ def __init__(
524
524
input_channel = output_channel
525
525
output_channel = block_out_channels [i + 1 ] if i + 1 < num_block_out_channels else block_out_channels [i ]
526
526
527
- down_block = LTXDownBlock3D (
527
+ down_block = LTXVideoDownBlock3D (
528
528
in_channels = input_channel ,
529
529
out_channels = output_channel ,
530
530
num_layers = layers_per_block [i ],
@@ -536,7 +536,7 @@ def __init__(
536
536
self .down_blocks .append (down_block )
537
537
538
538
# mid block
539
- self .mid_block = LTXMidBlock3d (
539
+ self .mid_block = LTXVideoMidBlock3d (
540
540
in_channels = output_channel ,
541
541
num_layers = layers_per_block [- 1 ],
542
542
resnet_eps = resnet_norm_eps ,
@@ -546,14 +546,14 @@ def __init__(
546
546
# out
547
547
self .norm_out = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = False )
548
548
self .conv_act = nn .SiLU ()
549
- self .conv_out = LTXCausalConv3d (
549
+ self .conv_out = LTXVideoCausalConv3d (
550
550
in_channels = output_channel , out_channels = out_channels + 1 , kernel_size = 3 , stride = 1 , is_causal = is_causal
551
551
)
552
552
553
553
self .gradient_checkpointing = False
554
554
555
555
def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
556
- r"""The forward method of the `LTXEncoder3D ` class."""
556
+ r"""The forward method of the `LTXVideoEncoder3d ` class."""
557
557
558
558
p = self .patch_size
559
559
p_t = self .patch_size_t
@@ -599,9 +599,10 @@ def create_forward(*inputs):
599
599
return hidden_states
600
600
601
601
602
- class LTXDecoder3d (nn .Module ):
602
+ class LTXVideoDecoder3d (nn .Module ):
603
603
r"""
604
- The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
604
+ The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
605
+ sample.
605
606
606
607
Args:
607
608
in_channels (`int`, defaults to 128):
@@ -647,11 +648,11 @@ def __init__(
647
648
layers_per_block = tuple (reversed (layers_per_block ))
648
649
output_channel = block_out_channels [0 ]
649
650
650
- self .conv_in = LTXCausalConv3d (
651
+ self .conv_in = LTXVideoCausalConv3d (
651
652
in_channels = in_channels , out_channels = output_channel , kernel_size = 3 , stride = 1 , is_causal = is_causal
652
653
)
653
654
654
- self .mid_block = LTXMidBlock3d (
655
+ self .mid_block = LTXVideoMidBlock3d (
655
656
in_channels = output_channel , num_layers = layers_per_block [0 ], resnet_eps = resnet_norm_eps , is_causal = is_causal
656
657
)
657
658
@@ -662,7 +663,7 @@ def __init__(
662
663
input_channel = output_channel
663
664
output_channel = block_out_channels [i ]
664
665
665
- up_block = LTXUpBlock3d (
666
+ up_block = LTXVideoUpBlock3d (
666
667
in_channels = input_channel ,
667
668
out_channels = output_channel ,
668
669
num_layers = layers_per_block [i + 1 ],
@@ -676,7 +677,7 @@ def __init__(
676
677
# out
677
678
self .norm_out = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = False )
678
679
self .conv_act = nn .SiLU ()
679
- self .conv_out = LTXCausalConv3d (
680
+ self .conv_out = LTXVideoCausalConv3d (
680
681
in_channels = output_channel , out_channels = self .out_channels , kernel_size = 3 , stride = 1 , is_causal = is_causal
681
682
)
682
683
@@ -777,7 +778,7 @@ def __init__(
777
778
) -> None :
778
779
super ().__init__ ()
779
780
780
- self .encoder = LTXEncoder3d (
781
+ self .encoder = LTXVideoEncoder3d (
781
782
in_channels = in_channels ,
782
783
out_channels = latent_channels ,
783
784
block_out_channels = block_out_channels ,
@@ -788,7 +789,7 @@ def __init__(
788
789
resnet_norm_eps = resnet_norm_eps ,
789
790
is_causal = encoder_causal ,
790
791
)
791
- self .decoder = LTXDecoder3d (
792
+ self .decoder = LTXVideoDecoder3d (
792
793
in_channels = latent_channels ,
793
794
out_channels = out_channels ,
794
795
block_out_channels = block_out_channels ,
@@ -837,7 +838,7 @@ def __init__(
837
838
self .tile_sample_stride_width = 448
838
839
839
840
def _set_gradient_checkpointing (self , module , value = False ):
840
- if isinstance (module , (LTXEncoder3d , LTXDecoder3d )):
841
+ if isinstance (module , (LTXVideoEncoder3d , LTXVideoDecoder3d )):
841
842
module .gradient_checkpointing = value
842
843
843
844
def enable_tiling (
0 commit comments