File tree Expand file tree Collapse file tree 2 files changed +17
-2
lines changed Expand file tree Collapse file tree 2 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -3641,10 +3641,22 @@ def dft_slow(x, M):
3641
3641
assert_almost_equal (fft [1 , :, :], np .imag (fft_npy ))
3642
3642
3643
3643
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
3644
- def func (x ):
3644
+ def func1 (x ):
3645
3645
op_ = tf .signal .rfft (x )
3646
3646
return tf .abs (op_ , name = _TFOUTPUT )
3647
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3647
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
3648
+
3649
+ def func2 (x ):
3650
+ op_ = tf .signal .rfft (x )
3651
+ return tf .cos (op_ , name = _TFOUTPUT )
3652
+ with self .assertRaises (ValueError ):
3653
+ self ._run_test_case (func2 , [_OUTPUT ], {_INPUT : x_val })
3654
+
3655
+ def func3 (x ):
3656
+ op_ = tf .signal .rfft (x )
3657
+ return tf .identity (op_ , name = _TFOUTPUT )
3658
+ with self .assertRaises (ValueError ):
3659
+ self ._run_test_case (func2 , [_OUTPUT ], {_INPUT : x_val })
3648
3660
3649
3661
3650
3662
if __name__ == '__main__' :
Original file line number Diff line number Diff line change @@ -99,6 +99,9 @@ def DFT_real(x, fft_length=None):
99
99
res = np.dot(cst, x)
100
100
return np.transpose(res, (0, 2, 1))
101
101
"""
102
+ consumers = ctx .find_output_consumers (node .output [0 ])
103
+ consumer_types = set (op .type for op in consumers )
104
+ utils .make_sure (consumer_types == {'ComplexAbs' }, "Current implementation of RFFT only allows ComplexAbs as consumer not %r" , consumer_types )
102
105
103
106
onnx_dtype = ctx .get_dtype (node .input [0 ])
104
107
shape = ctx .get_shape (node .input [0 ])
You can’t perform that action at this time.
0 commit comments