Skip to content

Commit 378f0e9

Browse files
committed
Better error message when it fails due to RFFT
Signed-off-by: xavier dupré <[email protected]>
1 parent b51df2f commit 378f0e9

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

tests/test_backend.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3641,10 +3641,22 @@ def dft_slow(x, M):
36413641
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
36423642

36433643
x_val = make_xval([3, 4]).astype(np.float32)
3644-
def func(x):
3644+
def func1(x):
36453645
op_ = tf.signal.rfft(x)
36463646
return tf.abs(op_, name=_TFOUTPUT)
3647-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3647+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
3648+
3649+
def func2(x):
3650+
op_ = tf.signal.rfft(x)
3651+
return tf.cos(op_, name=_TFOUTPUT)
3652+
with self.assertRaises(ValueError):
3653+
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
3654+
3655+
def func3(x):
3656+
op_ = tf.signal.rfft(x)
3657+
return tf.identity(op_, name=_TFOUTPUT)
3658+
with self.assertRaises(ValueError):
3659+
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
36483660

36493661

36503662
if __name__ == '__main__':

tf2onnx/onnx_opset/signal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def DFT_real(x, fft_length=None):
9999
res = np.dot(cst, x)
100100
return np.transpose(res, (0, 2, 1))
101101
"""
102+
consumers = ctx.find_output_consumers(node.output[0])
103+
consumer_types = set(op.type for op in consumers)
104+
utils.make_sure(consumer_types == {'ComplexAbs'}, "Current implementation of RFFT only allows ComplexAbs as consumer not %r", consumer_types)
102105

103106
onnx_dtype = ctx.get_dtype(node.input[0])
104107
shape = ctx.get_shape(node.input[0])

0 commit comments

Comments
 (0)