|
34 | 34 |
|
35 | 35 | _STRIDE1x1 = [1, 1, 1, 1]
|
36 | 36 | _KERNEL3x3 = [3, 3, 1, 1]
|
| 37 | +_DILATIONS1x1 = [1, 1, 1, 1] |
37 | 38 |
|
38 | 39 | # names for input and outputs for tests
|
39 | 40 | _TFINPUT = "input"
|
@@ -348,7 +349,7 @@ def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None, rt
|
348 | 349 | if strides is None:
|
349 | 350 | strides = _STRIDE1x1
|
350 | 351 | if dilations is None:
|
351 |
| - dilations = [1, 1, 1, 1] |
| 352 | + dilations = _DILATIONS1x1 |
352 | 353 | def func(x):
|
353 | 354 | kernel = tf.constant(w, dtype=tf.float32, name='k')
|
354 | 355 | conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, dilations=dilations)
|
@@ -3565,8 +3566,26 @@ def func(y, x):
|
3565 | 3566 | self._run_test_case(
|
3566 | 3567 | func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
|
3567 | 3568 |
|
3568 |
| - def test_conv2d_kernel_as_input(self): |
3569 |
| - return |
| 3569 | + def _conv_kernel_as_input_test(self, x_val, w_val, strides=None, |
| 3570 | + padding="VALID", dilations=None, rtol=1e-07): |
| 3571 | + if strides is None: |
| 3572 | + strides = _STRIDE1x1 |
| 3573 | + if dilations is None: |
| 3574 | + dilations = _DILATIONS1x1 |
| 3575 | + |
| 3576 | + def func(x, kernel): |
| 3577 | + conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding, |
| 3578 | + dilations=dilations) |
| 3579 | + return tf.identity(conv, name=_TFOUTPUT) |
| 3580 | + |
| 3581 | + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: w_val}, rtol=rtol) |
| 3582 | + |
| 3583 | + def test_conv2d_1_kernel_as_input(self): |
| 3584 | + x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC) |
| 3585 | + w_val = np.array([[2., 1., 1.], |
| 3586 | + [1., 3., 1.], |
| 3587 | + [1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3) |
| 3588 | + self._conv_kernel_as_input_test(x_val, w_val) |
3570 | 3589 |
|
3571 | 3590 | if __name__ == '__main__':
|
3572 | 3591 | unittest_main()
|
0 commit comments