Skip to content

Commit c9b4219

Browse files
authored
Merge pull request #1116 from xadupre/rttf2
Better error message when it fails due to RFFT
2 parents 2cf8854 + c70208b commit c9b4219

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

tests/test_backend.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3659,10 +3659,22 @@ def dft_slow(x, M):
36593659
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
36603660

36613661
x_val = make_xval([3, 4]).astype(np.float32)
3662-
def func(x):
3662+
def func1(x):
36633663
op_ = tf.signal.rfft(x)
36643664
return tf.abs(op_, name=_TFOUTPUT)
3665-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3665+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
3666+
3667+
def func2(x):
3668+
op_ = tf.signal.rfft(x)
3669+
return tf.cos(op_, name=_TFOUTPUT)
3670+
with self.assertRaises(ValueError):
3671+
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
3672+
3673+
def func3(x):
3674+
op_ = tf.signal.rfft(x)
3675+
return tf.identity(op_, name=_TFOUTPUT)
3676+
with self.assertRaises(ValueError):
3677+
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
36663678

36673679

36683680
if __name__ == '__main__':

tf2onnx/onnx_opset/signal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def DFT_real(x, fft_length=None):
9999
res = np.dot(cst, x)
100100
return np.transpose(res, (0, 2, 1))
101101
"""
102+
consumers = ctx.find_output_consumers(node.output[0])
103+
consumer_types = set(op.type for op in consumers)
104+
utils.make_sure(
105+
consumer_types == {'ComplexAbs'},
106+
"Current implementation of RFFT only allows ComplexAbs as consumer not %r",
107+
consumer_types)
102108

103109
onnx_dtype = ctx.get_dtype(node.input[0])
104110
shape = ctx.get_shape(node.input[0])

0 commit comments

Comments
 (0)