Skip to content

Commit 7876d47

Browse files
committed
addresses PR comments
Signed-off-by: xavier dupré <[email protected]>
1 parent 3112ec8 commit 7876d47

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

tests/test_backend.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
import tensorflow as tf
1919

2020
from tensorflow.python.ops import lookup_ops
21-
try:
22-
from tensorflow import signal as tf_signal
23-
except ImportError:
24-
tf_signal = None
2521
from backend_test_base import Tf2OnnxBackendTestBase
2622
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
2723
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
@@ -3629,24 +3625,24 @@ def test_conv2d_1_kernel_as_input(self):
36293625
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36303626
self._conv_kernel_as_input_test(x_val, w_val)
36313627

3632-
@unittest.skipIf(tf_signal is None, reason="TF does not have submodule signal.")
3628+
@check_tf_min_version("1.14")
36333629
def test_rfft_ops(self):
36343630

3635-
def DFT_slow(x, M):
3631+
def dft_slow(x, M):
36363632
xt = x.T
36373633
res = np.dot(M, xt)
36383634
return np.transpose(res, (0, 2, 1))
36393635

36403636
x_val = make_xval([2, 4]).astype(np.float32)
36413637
M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
3642-
fft = DFT_slow(x_val, M_both)
3638+
fft = dft_slow(x_val, M_both)
36433639
fft_npy = np.fft.rfft(x_val)
36443640
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
36453641
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
36463642

36473643
x_val = make_xval([3, 4]).astype(np.float32)
36483644
def func(x):
3649-
op_ = tf_signal.rfft(x)
3645+
op_ = tf.signal.rfft(x)
36503646
return tf.abs(op_, name=_TFOUTPUT)
36513647
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
36523648

0 commit comments

Comments
 (0)