Skip to content

Commit bea9e1b

Browse files
committed
refactor causal conv
1 parent 1b27c3a commit bea9e1b

File tree

1 file changed

+24
-66
lines changed

1 file changed

+24
-66
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 24 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..attention_processor import Attention, SpatialNorm
2929
from ..modeling_outputs import AutoencoderKLOutput
3030
from ..modeling_utils import ModelMixin
31-
from ..normalization import AdaGroupNorm, RMSNorm
31+
from ..normalization import AdaGroupNorm
3232
from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution
3333

3434

@@ -47,39 +47,36 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_
4747

4848

4949
class CausalConv3d(nn.Module):
50-
"""
51-
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial
52-
locations. This maintains temporal causality in video generation tasks.
53-
"""
54-
5550
def __init__(
5651
self,
57-
chan_in,
58-
chan_out,
59-
kernel_size: Union[int, Tuple[int, int, int]],
52+
in_channels: int,
53+
out_channels: int,
54+
kernel_size: Union[int, Tuple[int, int, int]] = 3,
6055
stride: Union[int, Tuple[int, int, int]] = 1,
56+
padding: Union[int, Tuple[int, int, int]] = 0,
6157
dilation: Union[int, Tuple[int, int, int]] = 1,
62-
pad_mode="replicate",
63-
**kwargs,
64-
):
58+
bias: bool = True,
59+
pad_mode: str = "replicate",
60+
) -> None:
6561
super().__init__()
6662

63+
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
64+
6765
self.pad_mode = pad_mode
68-
padding = (
69-
kernel_size // 2,
70-
kernel_size // 2,
71-
kernel_size // 2,
72-
kernel_size // 2,
73-
kernel_size - 1,
66+
self.time_causal_padding = (
67+
kernel_size[0] // 2,
68+
kernel_size[0] // 2,
69+
kernel_size[1] // 2,
70+
kernel_size[1] // 2,
71+
kernel_size[2] - 1,
7472
0,
75-
) # W, H, T
76-
self.time_causal_padding = padding
73+
)
7774

78-
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
75+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
7976

80-
def forward(self, x):
81-
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
82-
return self.conv(x)
77+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
78+
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
79+
return self.conv(hidden_states)
8380

8481

8582
class UpsampleCausal3D(nn.Module):
@@ -117,62 +114,25 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
117114

118115

119116
class DownsampleCausal3D(nn.Module):
120-
"""
121-
A 3D downsampling layer with an optional convolution.
122-
"""
123-
124117
def __init__(
125118
self,
126119
channels: int,
127-
use_conv: bool = False,
128120
out_channels: Optional[int] = None,
129121
padding: int = 1,
130-
name: str = "conv",
131122
kernel_size=3,
132-
norm_type=None,
133-
eps=None,
134-
elementwise_affine=None,
135123
bias=True,
136124
stride=2,
137125
):
138126
super().__init__()
139-
self.channels = channels
140-
self.out_channels = out_channels or channels
141-
self.use_conv = use_conv
142-
self.padding = padding
143-
stride = stride
144-
self.name = name
145-
146-
if norm_type == "ln_norm":
147-
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
148-
elif norm_type == "rms_norm":
149-
self.norm = RMSNorm(channels, eps, elementwise_affine)
150-
elif norm_type is None:
151-
self.norm = None
152-
else:
153-
raise ValueError(f"unknown norm_type: {norm_type}")
154-
155-
if use_conv:
156-
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
157-
else:
158-
raise NotImplementedError
159127

160-
if name == "conv":
161-
self.Conv2d_0 = conv
162-
self.conv = conv
163-
elif name == "Conv2d_0":
164-
self.conv = conv
165-
else:
166-
self.conv = conv
128+
out_channels = out_channels or channels
167129

168-
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
169-
assert hidden_states.shape[1] == self.channels
130+
self.conv = CausalConv3d(channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
170131

132+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
171133
if self.norm is not None:
172134
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
173135

174-
assert hidden_states.shape[1] == self.channels
175-
176136
hidden_states = self.conv(hidden_states)
177137

178138
return hidden_states
@@ -456,10 +416,8 @@ def __init__(
456416
[
457417
DownsampleCausal3D(
458418
out_channels,
459-
use_conv=True,
460419
out_channels=out_channels,
461420
padding=downsample_padding,
462-
name="op",
463421
stride=downsample_stride,
464422
)
465423
]

0 commit comments

Comments
 (0)