|
18 | 18 | import tensorflow as tf
|
19 | 19 |
|
20 | 20 | from tensorflow.python.ops import lookup_ops
|
21 |
| -try: |
22 |
| - from tensorflow import signal as tf_signal |
23 |
| -except ImportError: |
24 |
| - tf_signal = None |
25 | 21 | from backend_test_base import Tf2OnnxBackendTestBase
|
26 | 22 | # pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
|
27 | 23 | from common import * # pylint: disable=wildcard-import,unused-wildcard-import
|
@@ -3629,24 +3625,24 @@ def test_conv2d_1_kernel_as_input(self):
|
3629 | 3625 | [1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
|
3630 | 3626 | self._conv_kernel_as_input_test(x_val, w_val)
|
3631 | 3627 |
|
3632 |
| - @unittest.skipIf(tf_signal is None, reason="TF does not have submodule signal.") |
| 3628 | + @check_tf_min_version("1.14") |
3633 | 3629 | def test_rfft_ops(self):
|
3634 | 3630 |
|
3635 |
| - def DFT_slow(x, M): |
| 3631 | + def dft_slow(x, M): |
3636 | 3632 | xt = x.T
|
3637 | 3633 | res = np.dot(M, xt)
|
3638 | 3634 | return np.transpose(res, (0, 2, 1))
|
3639 | 3635 |
|
3640 | 3636 | x_val = make_xval([2, 4]).astype(np.float32)
|
3641 | 3637 | M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
|
3642 |
| - fft = DFT_slow(x_val, M_both) |
| 3638 | + fft = dft_slow(x_val, M_both) |
3643 | 3639 | fft_npy = np.fft.rfft(x_val)
|
3644 | 3640 | assert_almost_equal(fft[0, :, :], np.real(fft_npy))
|
3645 | 3641 | assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
|
3646 | 3642 |
|
3647 | 3643 | x_val = make_xval([3, 4]).astype(np.float32)
|
3648 | 3644 | def func(x):
|
3649 |
| - op_ = tf_signal.rfft(x) |
| 3645 | + op_ = tf.signal.rfft(x) |
3650 | 3646 | return tf.abs(op_, name=_TFOUTPUT)
|
3651 | 3647 | self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
|
3652 | 3648 |
|
|
0 commit comments