@@ -103,6 +103,7 @@ def __init__(
         self.width_pad = width_pad
         self.time_pad = time_pad
         self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+        self.const_padding_conv3d = (0, self.width_pad, self.height_pad)

         self.temporal_dim = 2
         self.time_kernel_size = time_kernel_size
@@ -115,6 +116,8 @@ def __init__(
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
+            padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
+            padding_mode="zeros",
         )

     def fake_context_parallel_forward(
@@ -135,9 +138,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
         if self.pad_mode == "replicate":
             conv_cache = None
         else:
-            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
             conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
-            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

         output = self.conv(inputs)
         return output, conv_cache
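The change folds the constant H/W padding into `nn.Conv3d` itself (via its `padding` and `padding_mode` arguments) instead of materializing a zero-padded copy of `inputs` with `F.pad` on every forward pass; time stays unpadded in both paths because causal temporal padding is handled separately through `time_causal_padding` and `conv_cache`. A minimal standalone sketch of the equivalence, not the diffusers class itself (the toy channel counts, shapes, and names like `conv_manual` are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
height_pad = width_pad = 1

# Old path: no built-in padding, spatial zero-padding done manually with F.pad.
conv_manual = nn.Conv3d(3, 4, kernel_size=3, padding=0, bias=False)
# New path: constant spatial padding folded into the conv; Conv3d's padding
# tuple is ordered (time, height, width), and time is left at 0.
conv_builtin = nn.Conv3d(
    3, 4, kernel_size=3,
    padding=(0, height_pad, width_pad),
    padding_mode="zeros",
    bias=False,
)
conv_builtin.weight.data.copy_(conv_manual.weight.data)

x = torch.randn(1, 3, 5, 8, 8)  # (batch, channels, frames, height, width)
# F.pad's tuple runs from the last dim inward: (W_left, W_right, H_top, H_bottom).
padding_2d = (width_pad, width_pad, height_pad, height_pad)
out_manual = conv_manual(F.pad(x, padding_2d, mode="constant", value=0))
out_builtin = conv_builtin(x)
print(torch.allclose(out_manual, out_builtin, atol=1e-6))  # True
```

With `pad_mode == "replicate"` the conv keeps `padding=0`, since replicate padding is applied elsewhere and cannot be expressed as the conv's own constant padding.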