Skip to content

Commit 3112ec8

Browse files
committed
Update test_backend.py
Signed-off-by: xavier dupré <[email protected]>
1 parent b94de17 commit 3112ec8

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/test_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import tensorflow as tf
1919

2020
from tensorflow.python.ops import lookup_ops
21+
try:
22+
from tensorflow import signal as tf_signal
23+
except ImportError:
24+
tf_signal = None
2125
from backend_test_base import Tf2OnnxBackendTestBase
2226
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
2327
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
@@ -3625,6 +3629,7 @@ def test_conv2d_1_kernel_as_input(self):
36253629
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36263630
self._conv_kernel_as_input_test(x_val, w_val)
36273631

3632+
@unittest.skipIf(tf_signal is None, reason="TF does not have submodule signal.")
36283633
def test_rfft_ops(self):
36293634

36303635
def DFT_slow(x, M):
@@ -3641,7 +3646,7 @@ def DFT_slow(x, M):
36413646

36423647
x_val = make_xval([3, 4]).astype(np.float32)
36433648
def func(x):
3644-
op_ = tf.signal.rfft(x)
3649+
op_ = tf_signal.rfft(x)
36453650
return tf.abs(op_, name=_TFOUTPUT)
36463651
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
36473652

0 commit comments

Comments
 (0)