@@ -395,6 +395,24 @@ def test_conv2d_6(self):
395
395
kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
396
396
self ._conv_test (x_val , kernel_val , strides = strides , padding = "VALID" , rtol = 1e-05 )
397
397
398
+ def test_conv2d_dilation_same (self ):
399
+ x_shape = [1 , 35 , 35 , 288 ] # NHWC
400
+ kernel_shape = [3 , 3 , 288 , 384 ] # [filter_height, filter_width, in_channels, out_channels]
401
+ strides = [1 , 1 , 1 , 1 ] # NHWC
402
+ dilations = [1 , 3 , 1 , 1 ] # NHWC
403
+ x_val = np .arange (1 , 1 + np .prod (x_shape )).astype ("float32" ).reshape (x_shape )
404
+ kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
405
+ self ._conv_test (x_val , kernel_val , strides = strides , padding = "SAME" , dilations = dilations , rtol = 1e-05 )
406
+
407
+ def test_conv2d_dilation_strides_same (self ):
408
+ x_shape = [1 , 35 , 35 , 288 ] # NHWC
409
+ kernel_shape = [3 , 3 , 288 , 384 ] # [filter_height, filter_width, in_channels, out_channels]
410
+ strides = [1 , 2 , 4 , 1 ] # NHWC
411
+ dilations = [1 , 3 , 1 , 1 ] # NHWC
412
+ x_val = np .arange (1 , 1 + np .prod (x_shape )).astype ("float32" ).reshape (x_shape )
413
+ kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
414
+ self ._conv_test (x_val , kernel_val , strides = strides , padding = "SAME" , dilations = dilations , rtol = 1e-05 )
415
+
398
416
def test_conv3d_1 (self ):
399
417
strides = [1 , 1 , 1 , 1 , 1 ]
400
418
dilations = [1 , 1 , 1 , 1 , 1 ]
0 commit comments