@@ -108,6 +108,20 @@ def integer_conv2d(self, x: torch.Tensor):
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
+        if x.numel() == 0:
+            N, _, H, W = x.shape
+            kernel_h, kernel_w = self.kernel_size
+            stride_h, stride_w = self.stride
+            pad_h, pad_w = self.padding
+            dil_h, dil_w = self.dilation
+
+            H_out = (H + 2 * pad_h - dil_h * (kernel_h - 1) - 1) // stride_h + 1
+            W_out = (W + 2 * pad_w - dil_w * (kernel_w - 1) - 1) // stride_w + 1
+
+            out_shape = (N, self.out_channels, H_out, W_out)
+            torch.backends.cudnn.enabled = _cudnn_enabled
+            return x.new_empty(out_shape).to(_dtype)
+
         ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
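Note: the H_out/W_out computation in the added branch follows the standard torch.nn.Conv2d output-size formula. A minimal standalone sketch (not part of the patch; the helper name conv_out_hw is illustrative only) that checks the same formula against a regular Conv2d:

# Standalone check of the conv output-size formula used in the hunk above.
import torch

def conv_out_hw(h, w, kernel, stride, padding, dilation):
    kh, kw = kernel
    sh, sw = stride
    ph, pw = padding
    dh, dw = dilation
    # Same formula as in the patch (and the torch.nn.Conv2d documentation).
    h_out = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return h_out, w_out

conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, dilation=1)
x = torch.randn(2, 3, 16, 16)
assert conv(x).shape[2:] == conv_out_hw(16, 16, (3, 3), (2, 2), (1, 1), (1, 1))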
@@ -226,6 +240,25 @@ def integer_transposeconv2d(self, x: torch.Tensor):
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
+        if x.numel() == 0:
+            N, _, H, W = x.shape
+            kernel_h, kernel_w = self.kernel_size
+            stride_h, stride_w = self.stride
+            pad_h, pad_w = self.padding
+            dil_h, dil_w = self.dilation
+            out_pad_h, out_pad_w = self.output_padding
+
+            H_out = (
+                (H - 1) * stride_h - 2 * pad_h + dil_h * (kernel_h - 1) + out_pad_h + 1
+            )
+            W_out = (
+                (W - 1) * stride_w - 2 * pad_w + dil_w * (kernel_w - 1) + out_pad_w + 1
+            )
+
+            out_shape = (N, self.out_channels, H_out, W_out)
+            torch.backends.cudnn.enabled = _cudnn_enabled
+            return x.new_empty(out_shape).to(_dtype)
+
         ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
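Note: likewise, the transposed-conv branch mirrors the torch.nn.ConvTranspose2d output-size formula. A minimal standalone sketch (not part of the patch; tconv_out_hw is an illustrative name) checking it against ConvTranspose2d:

# Standalone check of the transposed-conv output-size formula used in the hunk above.
import torch

def tconv_out_hw(h, w, kernel, stride, padding, dilation, output_padding):
    kh, kw = kernel
    sh, sw = stride
    ph, pw = padding
    dh, dw = dilation
    oph, opw = output_padding
    # Same formula as in the patch (and the torch.nn.ConvTranspose2d documentation).
    h_out = (h - 1) * sh - 2 * ph + dh * (kh - 1) + oph + 1
    w_out = (w - 1) * sw - 2 * pw + dw * (kw - 1) + opw + 1
    return h_out, w_out

tconv = torch.nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1)
x = torch.randn(2, 8, 8, 8)
assert tconv(x).shape[2:] == tconv_out_hw(8, 8, (4, 4), (2, 2), (1, 1), (1, 1), (0, 0))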