@@ -199,13 +199,15 @@ def test_avgppol(self):
199
199
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
200
200
self .assertAllClose (expected , actual )
201
201
202
- def _conv_test (self , x_val , w , strides = None , padding = "VALID" ):
202
+ def _conv_test (self , x_val , w , strides = None , padding = "VALID" , dilations = None ):
203
203
if strides is None :
204
204
strides = _STRIDE1x1
205
+ if dilations is None :
206
+ dilations = [1 , 1 , 1 , 1 ]
205
207
tf .reset_default_graph ()
206
208
kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
207
209
x = tf .placeholder (tf .float32 , shape = x_val .shape , name = _TFINPUT )
208
- conv = tf .nn .conv2d (x , kernel , strides = strides , padding = padding )
210
+ conv = tf .nn .conv2d (x , kernel , strides = strides , padding = padding , dilations = dilations )
209
211
output = tf .identity (conv , name = _TFOUTPUT )
210
212
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
211
213
return actual , expected
@@ -259,6 +261,17 @@ def test_conv2d_6(self):
259
261
expected , actual = self ._conv_test (x_val , kernel_val , strides = strides , padding = "VALID" )
260
262
self .assertAllClose (expected , actual , rtol = 1e-05 )
261
263
264
+
265
+ def test_conv2d_7 (self ):
266
+ x_shape = [1 , 35 , 35 , 288 ] # out: [1, 17, 17, 384]
267
+ kernel_shape = [3 , 3 , 288 , 384 ]
268
+ strides = [1 , 2 , 2 , 1 ]
269
+ dilations = [1 , 3 , 3 , 1 ]
270
+ x_val = np .arange (1 , 1 + np .prod (x_shape )).astype ("float32" ).reshape (x_shape )
271
+ kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
272
+ expected , actual = self ._conv_test (x_val , kernel_val , strides = strides , padding = "VALID" , dilations = dilations )
273
+ self .assertAllClose (expected , actual , rtol = 1e-05 )
274
+
262
275
def test_conv2d_transpose (self ):
263
276
x_shape = [2 , 6 , 4 , 3 ]
264
277
output_shape = [2 , 13 , 9 , 2 ]
0 commit comments