File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -3659,10 +3659,22 @@ def dft_slow(x, M):
3659
3659
assert_almost_equal (fft [1 , :, :], np .imag (fft_npy ))
3660
3660
3661
3661
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
3662
- def func (x ):
3662
+ def func1 (x ):
3663
3663
op_ = tf .signal .rfft (x )
3664
3664
return tf .abs (op_ , name = _TFOUTPUT )
3665
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3665
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
3666
+
3667
+ def func2 (x ):
3668
+ op_ = tf .signal .rfft (x )
3669
+ return tf .cos (op_ , name = _TFOUTPUT )
3670
+ with self .assertRaises (ValueError ):
3671
+ self ._run_test_case (func2 , [_OUTPUT ], {_INPUT : x_val })
3672
+
3673
+ def func3 (x ):
3674
+ op_ = tf .signal .rfft (x )
3675
+ return tf .identity (op_ , name = _TFOUTPUT )
3676
+ with self .assertRaises (ValueError ):
3677
+ self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
3666
3678
3667
3679
3668
3680
if __name__ == '__main__' :
Original file line number Diff line number Diff line change @@ -99,6 +99,12 @@ def DFT_real(x, fft_length=None):
99
99
res = np.dot(cst, x)
100
100
return np.transpose(res, (0, 2, 1))
101
101
"""
102
+ consumers = ctx .find_output_consumers (node .output [0 ])
103
+ consumer_types = set (op .type for op in consumers )
104
+ utils .make_sure (
105
+ consumer_types == {'ComplexAbs' },
106
+ "Current implementation of RFFT only allows ComplexAbs as consumer not %r" ,
107
+ consumer_types )
102
108
103
109
onnx_dtype = ctx .get_dtype (node .input [0 ])
104
110
shape = ctx .get_shape (node .input [0 ])
You can’t perform that action at this time.
0 commit comments