Skip to content

Commit 57609b0

Browse files
committed
#1070 adding test with kernel as input
1 parent 371e1b6 commit 57609b0

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

tests/test_backend.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
_STRIDE1x1 = [1, 1, 1, 1]
3636
_KERNEL3x3 = [3, 3, 1, 1]
37+
_DILATIONS1x1 = [1, 1, 1, 1]
3738

3839
# names for input and outputs for tests
3940
_TFINPUT = "input"
@@ -348,7 +349,7 @@ def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None, rt
348349
if strides is None:
349350
strides = _STRIDE1x1
350351
if dilations is None:
351-
dilations = [1, 1, 1, 1]
352+
dilations = _DILATIONS1x1
352353
def func(x):
353354
kernel = tf.constant(w, dtype=tf.float32, name='k')
354355
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, dilations=dilations)
@@ -3565,8 +3566,26 @@ def func(y, x):
35653566
self._run_test_case(
35663567
func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
35673568

3568-
def test_conv2d_kernel_as_input(self):
3569-
return
3569+
def _conv_kernel_as_input_test(self, x_val, w_val, strides=None,
3570+
padding="VALID", dilations=None, rtol=1e-07):
3571+
if strides is None:
3572+
strides = _STRIDE1x1
3573+
if dilations is None:
3574+
dilations = _DILATIONS1x1
3575+
3576+
def func(x, kernel):
3577+
conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding,
3578+
dilations=dilations)
3579+
return tf.identity(conv, name=_TFOUTPUT)
3580+
3581+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: w_val}, rtol=rtol)
3582+
3583+
def test_conv2d_1_kernel_as_input(self):
3584+
x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC)
3585+
w_val = np.array([[2., 1., 1.],
3586+
[1., 3., 1.],
3587+
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
3588+
self._conv_kernel_as_input_test(x_val, w_val)
35703589

35713590
if __name__ == '__main__':
35723591
unittest_main()

0 commit comments

Comments
 (0)