 from ..attention_processor import Attention, SpatialNorm
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaGroupNorm, RMSNorm
+from ..normalization import AdaGroupNorm
 from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution
 
 
@@ -47,39 +47,36 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_
 
 
 class CausalConv3d(nn.Module):
-    """
-    Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial
-    locations. This maintains temporal causality in video generation tasks.
-    """
-
     def __init__(
         self,
-        chan_in,
-        chan_out,
-        kernel_size: Union[int, Tuple[int, int, int]],
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]] = 3,
         stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
         dilation: Union[int, Tuple[int, int, int]] = 1,
-        pad_mode="replicate",
-        **kwargs,
-    ):
+        bias: bool = True,
+        pad_mode: str = "replicate",
+    ) -> None:
         super().__init__()
 
+        kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
         self.pad_mode = pad_mode
-        padding = (
-            kernel_size // 2,
-            kernel_size // 2,
-            kernel_size // 2,
-            kernel_size // 2,
-            kernel_size - 1,
+        self.time_causal_padding = (
+            kernel_size[0] // 2,
+            kernel_size[0] // 2,
+            kernel_size[1] // 2,
+            kernel_size[1] // 2,
+            kernel_size[2] - 1,
             0,
-        )  # W, H, T
-        self.time_causal_padding = padding
+        )
 
-        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
 
-    def forward(self, x):
-        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
-        return self.conv(x)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
+        return self.conv(hidden_states)
 
 
 class UpsampleCausal3D(nn.Module):
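A note on the new padding tuple: `F.pad` consumes its argument last-dimension-first, so for a 5D `(batch, channels, frames, height, width)` tensor the 6-tuple reads `(W_left, W_right, H_top, H_bottom, T_front, T_back)` — symmetric padding in space, all `kernel_size - 1` frames of temporal padding at the front and none at the back, which is what makes the convolution causal. (The mixing of `kernel_size[0]` and `kernel_size[2]` across axes is harmless only because cubic kernels are used; `nn.Conv3d` orders its kernel as `(kT, kH, kW)`.) A minimal sketch of the resulting behavior, assuming `CausalConv3d` from this diff is in scope:

```python
import torch

conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3)
video = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)

# Temporal length is preserved: 2 front-padded frames feed a kernel of size 3.
out = conv(video)
print(out.shape)  # torch.Size([1, 4, 8, 16, 16])

# Causality: perturbing the last frame must leave all earlier output frames intact.
video_perturbed = video.clone()
video_perturbed[:, :, -1] += 1.0
out_perturbed = conv(video_perturbed)
print(torch.allclose(out[:, :, :-1], out_perturbed[:, :, :-1]))  # True
```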
@@ -117,62 +114,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class DownsampleCausal3D(nn.Module):
-    """
-    A 3D downsampling layer with an optional convolution.
-    """
-
     def __init__(
         self,
         channels: int,
-        use_conv: bool = False,
         out_channels: Optional[int] = None,
         padding: int = 1,
-        name: str = "conv",
         kernel_size=3,
-        norm_type=None,
-        eps=None,
-        elementwise_affine=None,
         bias=True,
         stride=2,
     ):
         super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.padding = padding
-        stride = stride
-        self.name = name
-
-        if norm_type == "ln_norm":
-            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
-        elif norm_type == "rms_norm":
-            self.norm = RMSNorm(channels, eps, elementwise_affine)
-        elif norm_type is None:
-            self.norm = None
-        else:
-            raise ValueError(f"unknown norm_type: {norm_type}")
-
-        if use_conv:
-            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
-        else:
-            raise NotImplementedError
 
-        if name == "conv":
-            self.Conv2d_0 = conv
-            self.conv = conv
-        elif name == "Conv2d_0":
-            self.conv = conv
-        else:
-            self.conv = conv
+        out_channels = out_channels or channels
 
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        assert hidden_states.shape[1] == self.channels
+        self.conv = CausalConv3d(channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
 
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.norm is not None:
-            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
-        assert hidden_states.shape[1] == self.channels
-
         hidden_states = self.conv(hidden_states)
 
         return hidden_states
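With the `use_conv` / `name` / `norm_type` branches gone (the `self.norm` check in `forward` has to go with them, since the attribute is no longer set), `DownsampleCausal3D` reduces to a single strided `CausalConv3d`; note that the retained `padding` argument is no longer forwarded anywhere, so only the causal padding applies. A quick shape check, assuming the classes from this diff are in scope — with the default kernel of 3 and stride 2, each of the last three dimensions maps from `L` to `(L - 1) // 2 + 1`:

```python
import torch

down = DownsampleCausal3D(channels=8, stride=2)
video = torch.randn(1, 8, 8, 16, 16)  # (batch, channels, frames, height, width)
print(down(video).shape)  # torch.Size([1, 8, 4, 8, 8])
```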
@@ -456,10 +416,8 @@ def __init__(
                 [
                     DownsampleCausal3D(
                         out_channels,
-                        use_conv=True,
                         out_channels=out_channels,
                         padding=downsample_padding,
-                        name="op",
                         stride=downsample_stride,
                     )
                 ]
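Because the stride is forwarded untouched to `nn.Conv3d` (which accepts 3-tuples as well as ints, matching the `Union[int, Tuple[int, int, int]]` hint on `CausalConv3d`), a tuple `downsample_stride` can downsample space while keeping the frame count. An illustrative sketch with a hypothetical stride value, not one taken from the model config:

```python
import torch

down_spatial = DownsampleCausal3D(channels=8, stride=(1, 2, 2))
video = torch.randn(1, 8, 8, 16, 16)
print(down_spatial(video).shape)  # torch.Size([1, 8, 8, 8, 8])
```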