|
34 | 34 |
|
35 | 35 | _STRIDE1x1 = [1, 1, 1, 1]
|
36 | 36 | _KERNEL3x3 = [3, 3, 1, 1]
|
| 37 | +_DILATIONS1x1 = [1, 1, 1, 1] |
37 | 38 |
|
38 | 39 | # names for input and outputs for tests
|
39 | 40 | _TFINPUT = "input"
|
@@ -348,7 +349,7 @@ def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None, rt
|
348 | 349 | if strides is None:
|
349 | 350 | strides = _STRIDE1x1
|
350 | 351 | if dilations is None:
|
351 |
| - dilations = [1, 1, 1, 1] |
| 352 | + dilations = _DILATIONS1x1 |
352 | 353 | def func(x):
|
353 | 354 | kernel = tf.constant(w, dtype=tf.float32, name='k')
|
354 | 355 | conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, dilations=dilations)
|
@@ -3580,6 +3581,27 @@ def func(y, x):
|
3580 | 3581 | self._run_test_case(
|
3581 | 3582 | func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
|
3582 | 3583 |
|
| 3584 | + def _conv_kernel_as_input_test(self, x_val, w_val, strides=None, |
| 3585 | + padding="VALID", dilations=None, rtol=1e-07): |
| 3586 | + if strides is None: |
| 3587 | + strides = _STRIDE1x1 |
| 3588 | + if dilations is None: |
| 3589 | + dilations = _DILATIONS1x1 |
| 3590 | + |
| 3591 | + def func(x, kernel): |
| 3592 | + conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, |
| 3593 | + dilations=dilations) |
| 3594 | + return tf.identity(conv, name=_TFOUTPUT) |
| 3595 | + |
| 3596 | + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: w_val}, rtol=rtol) |
| 3597 | + |
| 3598 | + def test_conv2d_1_kernel_as_input(self): |
| 3599 | + x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC) |
| 3600 | + w_val = np.array([[2., 1., 1.], |
| 3601 | + [1., 3., 1.], |
| 3602 | + [1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3) |
| 3603 | + self._conv_kernel_as_input_test(x_val, w_val) |
| 3604 | + |
3583 | 3605 |
|
3584 | 3606 | if __name__ == '__main__':
|
3585 | 3607 | unittest_main()
|
0 commit comments