Skip to content

Commit c586b4b

Browse files
committed
replace F.pad with built-in padding in Conv3D
1 parent c5ce24f commit c586b4b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
self.width_pad = width_pad
104104
self.time_pad = time_pad
105105
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
106+
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
106107

107108
self.temporal_dim = 2
108109
self.time_kernel_size = time_kernel_size
@@ -115,6 +116,8 @@ def __init__(
115116
kernel_size=kernel_size,
116117
stride=stride,
117118
dilation=dilation,
119+
padding = 0 if self.pad_mode == 'replicate' else self.const_padding_conv3d,
120+
padding_mode = 'zeros',
118121
)
119122

120123
def fake_context_parallel_forward(
@@ -135,9 +138,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non
135138
if self.pad_mode == "replicate":
136139
conv_cache = None
137140
else:
138-
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
139141
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
140-
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
141142

142143
output = self.conv(inputs)
143144
return output, conv_cache

0 commit comments

Comments
 (0)