Skip to content

Commit bd92ea4

Browse files
authored
Merge pull request #45 from onnx/gs/onnx-1.2
add test case for dilations
2 parents 24119b5 + 887449d commit bd92ea4

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tests/test_backend.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,15 @@ def test_avgppol(self):
199199
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
200200
self.assertAllClose(expected, actual)
201201

202-
def _conv_test(self, x_val, w, strides=None, padding="VALID"):
202+
def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None):
203203
if strides is None:
204204
strides = _STRIDE1x1
205+
if dilations is None:
206+
dilations = [1, 1, 1, 1]
205207
tf.reset_default_graph()
206208
kernel = tf.constant(w, dtype=tf.float32, name='k')
207209
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
208-
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding)
210+
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, dilations=dilations)
209211
output = tf.identity(conv, name=_TFOUTPUT)
210212
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
211213
return actual, expected
@@ -259,6 +261,17 @@ def test_conv2d_6(self):
259261
expected, actual = self._conv_test(x_val, kernel_val, strides=strides, padding="VALID")
260262
self.assertAllClose(expected, actual, rtol=1e-05)
261263

264+
265+
def test_conv2d_7(self):
266+
x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384]
267+
kernel_shape = [3, 3, 288, 384]
268+
strides = [1, 2, 2, 1]
269+
dilations = [1, 3, 3, 1]
270+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
271+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
272+
expected, actual = self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", dilations=dilations)
273+
self.assertAllClose(expected, actual, rtol=1e-05)
274+
262275
def test_conv2d_transpose(self):
263276
x_shape = [2, 6, 4, 3]
264277
output_shape = [2, 13, 9, 2]

0 commit comments

Comments
 (0)