2929from .vae import DecoderOutput , DiagonalGaussianDistribution
3030
3131
32- class LTXCausalConv3d (nn .Module ):
32+ class LTXVideoCausalConv3d (nn .Module ):
3333 def __init__ (
3434 self ,
3535 in_channels : int ,
@@ -80,9 +80,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
8080 return hidden_states
8181
8282
83- class LTXResnetBlock3d (nn .Module ):
83+ class LTXVideoResnetBlock3d (nn .Module ):
8484 r"""
85- A 3D ResNet block used in the LTX model.
85+ A 3D ResNet block used in the LTXVideo model.
8686
8787 Args:
8888 in_channels (`int`):
@@ -120,21 +120,21 @@ def __init__(
120120 self .nonlinearity = get_activation (non_linearity )
121121
122122 self .norm1 = RMSNorm (in_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
123- self .conv1 = LTXCausalConv3d (
123+ self .conv1 = LTXVideoCausalConv3d (
124124 in_channels = in_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
125125 )
126126
127127 self .norm2 = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
128128 self .dropout = nn .Dropout (dropout )
129- self .conv2 = LTXCausalConv3d (
129+ self .conv2 = LTXVideoCausalConv3d (
130130 in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
131131 )
132132
133133 self .norm3 = None
134134 self .conv_shortcut = None
135135 if in_channels != out_channels :
136136 self .norm3 = nn .LayerNorm (in_channels , eps = eps , elementwise_affine = True , bias = True )
137- self .conv_shortcut = LTXCausalConv3d (
137+ self .conv_shortcut = LTXVideoCausalConv3d (
138138 in_channels = in_channels , out_channels = out_channels , kernel_size = 1 , stride = 1 , is_causal = is_causal
139139 )
140140
@@ -196,7 +196,7 @@ def forward(
196196 return hidden_states
197197
198198
199- class LTXUpsampler3d (nn .Module ):
199+ class LTXVideoUpsampler3d (nn .Module ):
200200 def __init__ (
201201 self ,
202202 in_channels : int ,
@@ -213,7 +213,7 @@ def __init__(
213213
214214 out_channels = (in_channels * stride [0 ] * stride [1 ] * stride [2 ]) // upscale_factor
215215
216- self .conv = LTXCausalConv3d (
216+ self .conv = LTXVideoCausalConv3d (
217217 in_channels = in_channels ,
218218 out_channels = out_channels ,
219219 kernel_size = 3 ,
@@ -246,9 +246,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
246246 return hidden_states
247247
248248
249- class LTXDownBlock3D (nn .Module ):
249+ class LTXVideoDownBlock3D (nn .Module ):
250250 r"""
251- Down block used in the LTX model.
251+ Down block used in the LTXVideo model.
252252
253253 Args:
254254 in_channels (`int`):
@@ -290,7 +290,7 @@ def __init__(
290290 resnets = []
291291 for _ in range (num_layers ):
292292 resnets .append (
293- LTXResnetBlock3d (
293+ LTXVideoResnetBlock3d (
294294 in_channels = in_channels ,
295295 out_channels = in_channels ,
296296 dropout = dropout ,
@@ -305,7 +305,7 @@ def __init__(
305305 if spatio_temporal_scale :
306306 self .downsamplers = nn .ModuleList (
307307 [
308- LTXCausalConv3d (
308+ LTXVideoCausalConv3d (
309309 in_channels = in_channels ,
310310 out_channels = in_channels ,
311311 kernel_size = 3 ,
@@ -317,7 +317,7 @@ def __init__(
317317
318318 self .conv_out = None
319319 if in_channels != out_channels :
320- self .conv_out = LTXResnetBlock3d (
320+ self .conv_out = LTXVideoResnetBlock3d (
321321 in_channels = in_channels ,
322322 out_channels = out_channels ,
323323 dropout = dropout ,
@@ -362,9 +362,9 @@ def create_forward(*inputs):
362362
363363
364364# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
365- class LTXMidBlock3d (nn .Module ):
365+ class LTXVideoMidBlock3d (nn .Module ):
366366 r"""
367- A middle block used in the LTX model.
367+ A middle block used in the LTXVideo model.
368368
369369 Args:
370370 in_channels (`int`):
@@ -403,7 +403,7 @@ def __init__(
403403 resnets = []
404404 for _ in range (num_layers ):
405405 resnets .append (
406- LTXResnetBlock3d (
406+ LTXVideoResnetBlock3d (
407407 in_channels = in_channels ,
408408 out_channels = in_channels ,
409409 dropout = dropout ,
@@ -454,9 +454,9 @@ def create_forward(*inputs):
454454 return hidden_states
455455
456456
457- class LTXUpBlock3d (nn .Module ):
457+ class LTXVideoUpBlock3d (nn .Module ):
458458 r"""
459- Up block used in the LTX model.
459+ Up block used in the LTXVideo model.
460460
461461 Args:
462462 in_channels (`int`):
@@ -505,7 +505,7 @@ def __init__(
505505
506506 self .conv_in = None
507507 if in_channels != out_channels :
508- self .conv_in = LTXResnetBlock3d (
508+ self .conv_in = LTXVideoResnetBlock3d (
509509 in_channels = in_channels ,
510510 out_channels = out_channels ,
511511 dropout = dropout ,
@@ -520,7 +520,7 @@ def __init__(
520520 if spatio_temporal_scale :
521521 self .upsamplers = nn .ModuleList (
522522 [
523- LTXUpsampler3d (
523+ LTXVideoUpsampler3d (
524524 out_channels * upscale_factor ,
525525 stride = (2 , 2 , 2 ),
526526 is_causal = is_causal ,
@@ -533,7 +533,7 @@ def __init__(
533533 resnets = []
534534 for _ in range (num_layers ):
535535 resnets .append (
536- LTXResnetBlock3d (
536+ LTXVideoResnetBlock3d (
537537 in_channels = out_channels ,
538538 out_channels = out_channels ,
539539 dropout = dropout ,
@@ -589,9 +589,9 @@ def create_forward(*inputs):
589589 return hidden_states
590590
591591
592- class LTXEncoder3d (nn .Module ):
592+ class LTXVideoEncoder3d (nn .Module ):
593593 r"""
594- The `LTXEncoder3D ` layer of a variational autoencoder that encodes input video samples to its latent
594+ The `LTXVideoEncoder3d ` layer of a variational autoencoder that encodes input video samples to its latent
595595 representation.
596596
597597 Args:
@@ -635,7 +635,7 @@ def __init__(
635635
636636 output_channel = block_out_channels [0 ]
637637
638- self .conv_in = LTXCausalConv3d (
638+ self .conv_in = LTXVideoCausalConv3d (
639639 in_channels = self .in_channels ,
640640 out_channels = output_channel ,
641641 kernel_size = 3 ,
@@ -650,7 +650,7 @@ def __init__(
650650 input_channel = output_channel
651651 output_channel = block_out_channels [i + 1 ] if i + 1 < num_block_out_channels else block_out_channels [i ]
652652
653- down_block = LTXDownBlock3D (
653+ down_block = LTXVideoDownBlock3D (
654654 in_channels = input_channel ,
655655 out_channels = output_channel ,
656656 num_layers = layers_per_block [i ],
@@ -662,7 +662,7 @@ def __init__(
662662 self .down_blocks .append (down_block )
663663
664664 # mid block
665- self .mid_block = LTXMidBlock3d (
665+ self .mid_block = LTXVideoMidBlock3d (
666666 in_channels = output_channel ,
667667 num_layers = layers_per_block [- 1 ],
668668 resnet_eps = resnet_norm_eps ,
@@ -672,14 +672,14 @@ def __init__(
672672 # out
673673 self .norm_out = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = False )
674674 self .conv_act = nn .SiLU ()
675- self .conv_out = LTXCausalConv3d (
675+ self .conv_out = LTXVideoCausalConv3d (
676676 in_channels = output_channel , out_channels = out_channels + 1 , kernel_size = 3 , stride = 1 , is_causal = is_causal
677677 )
678678
679679 self .gradient_checkpointing = False
680680
681681 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
682- r"""The forward method of the `LTXEncoder3D ` class."""
682+ r"""The forward method of the `LTXVideoEncoder3d ` class."""
683683
684684 p = self .patch_size
685685 p_t = self .patch_size_t
@@ -725,9 +725,10 @@ def create_forward(*inputs):
725725 return hidden_states
726726
727727
728- class LTXDecoder3d (nn .Module ):
728+ class LTXVideoDecoder3d (nn .Module ):
729729 r"""
730- The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
730+ The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
731+ sample.
731732
732733 Args:
733734 in_channels (`int`, defaults to 128):
@@ -782,11 +783,11 @@ def __init__(
782783 upsample_factor = tuple (reversed (upsample_factor ))
783784 output_channel = block_out_channels [0 ]
784785
785- self .conv_in = LTXCausalConv3d (
786+ self .conv_in = LTXVideoCausalConv3d (
786787 in_channels = in_channels , out_channels = output_channel , kernel_size = 3 , stride = 1 , is_causal = is_causal
787788 )
788789
789- self .mid_block = LTXMidBlock3d (
790+ self .mid_block = LTXVideoMidBlock3d (
790791 in_channels = output_channel ,
791792 num_layers = layers_per_block [0 ],
792793 resnet_eps = resnet_norm_eps ,
@@ -802,7 +803,7 @@ def __init__(
802803 input_channel = output_channel // upsample_factor [i ]
803804 output_channel = block_out_channels [i ] // upsample_factor [i ]
804805
805- up_block = LTXUpBlock3d (
806+ up_block = LTXVideoUpBlock3d (
806807 in_channels = input_channel ,
807808 out_channels = output_channel ,
808809 num_layers = layers_per_block [i + 1 ],
@@ -820,7 +821,7 @@ def __init__(
820821 # out
821822 self .norm_out = RMSNorm (out_channels , eps = 1e-8 , elementwise_affine = False )
822823 self .conv_act = nn .SiLU ()
823- self .conv_out = LTXCausalConv3d (
824+ self .conv_out = LTXVideoCausalConv3d (
824825 in_channels = output_channel , out_channels = self .out_channels , kernel_size = 3 , stride = 1 , is_causal = is_causal
825826 )
826827
@@ -951,7 +952,7 @@ def __init__(
951952 ) -> None :
952953 super ().__init__ ()
953954
954- self .encoder = LTXEncoder3d (
955+ self .encoder = LTXVideoEncoder3d (
955956 in_channels = in_channels ,
956957 out_channels = latent_channels ,
957958 block_out_channels = block_out_channels ,
@@ -962,7 +963,7 @@ def __init__(
962963 resnet_norm_eps = resnet_norm_eps ,
963964 is_causal = encoder_causal ,
964965 )
965- self .decoder = LTXDecoder3d (
966+ self .decoder = LTXVideoDecoder3d (
966967 in_channels = latent_channels ,
967968 out_channels = out_channels ,
968969 block_out_channels = decoder_block_out_channels ,
@@ -1015,7 +1016,7 @@ def __init__(
10151016 self .tile_sample_stride_width = 448
10161017
10171018 def _set_gradient_checkpointing (self , module , value = False ):
1018- if isinstance (module , (LTXEncoder3d , LTXDecoder3d )):
1019+ if isinstance (module , (LTXVideoEncoder3d , LTXVideoDecoder3d )):
10191020 module .gradient_checkpointing = value
10201021
10211022 def enable_tiling (
0 commit comments