@@ -36,13 +36,6 @@ def make_dft_constant(length, dtype, fft_length):
36
36
@tf_op ("RFFT" )
37
37
class RFFTOp :
38
38
# support more dtype
39
- supported_dtypes = [
40
- onnx_pb .TensorProto .FLOAT ,
41
- onnx_pb .TensorProto .FLOAT16 ,
42
- onnx_pb .TensorProto .DOUBLE ,
43
- onnx_pb .TensorProto .COMPLEX64 ,
44
- onnx_pb .TensorProto .COMPLEX128 ,
45
- ]
46
39
47
40
@classmethod
48
41
def version_1 (cls , ctx , node , ** kwargs ):
@@ -99,6 +92,13 @@ def DFT_real(x, fft_length=None):
99
92
res = np.dot(cst, x)
100
93
return np.transpose(res, (0, 2, 1))
101
94
"""
95
+ supported_dtypes = [
96
+ onnx_pb .TensorProto .FLOAT ,
97
+ onnx_pb .TensorProto .FLOAT16 ,
98
+ onnx_pb .TensorProto .DOUBLE ,
99
+ onnx_pb .TensorProto .COMPLEX64 ,
100
+ onnx_pb .TensorProto .COMPLEX128 ,
101
+ ]
102
102
consumers = ctx .find_output_consumers (node .output [0 ])
103
103
consumer_types = set (op .type for op in consumers )
104
104
utils .make_sure (
@@ -107,6 +107,7 @@ def DFT_real(x, fft_length=None):
107
107
consumer_types )
108
108
109
109
onnx_dtype = ctx .get_dtype (node .input [0 ])
110
+ utils .make_sure (onnx_dtype in supported_dtypes , "Unsupported input type." )
110
111
shape = ctx .get_shape (node .input [0 ])
111
112
np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
112
113
shape_n = shape [- 1 ]
@@ -164,13 +165,6 @@ def DFT_real(x, fft_length=None):
164
165
@tf_op ("ComplexAbs" )
165
166
class ComplexAbsOp :
166
167
# support more dtype
167
- supported_dtypes = [
168
- onnx_pb .TensorProto .FLOAT ,
169
- onnx_pb .TensorProto .FLOAT16 ,
170
- onnx_pb .TensorProto .DOUBLE ,
171
- onnx_pb .TensorProto .COMPLEX64 ,
172
- onnx_pb .TensorProto .COMPLEX128 ,
173
- ]
174
168
175
169
@classmethod
176
170
def any_version (cls , opset , ctx , node , ** kwargs ):
@@ -180,7 +174,15 @@ def any_version(cls, opset, ctx, node, **kwargs):
180
174
it assumes the first dimension means real part (0)
181
175
and imaginary part (1, :, :...).
182
176
"""
177
+ supported_dtypes = [
178
+ onnx_pb .TensorProto .FLOAT ,
179
+ onnx_pb .TensorProto .FLOAT16 ,
180
+ onnx_pb .TensorProto .DOUBLE ,
181
+ onnx_pb .TensorProto .COMPLEX64 ,
182
+ onnx_pb .TensorProto .COMPLEX128 ,
183
+ ]
183
184
onnx_dtype = ctx .get_dtype (node .input [0 ])
185
+ utils .make_sure (onnx_dtype in supported_dtypes , "Unsupported input type." )
184
186
shape = ctx .get_shape (node .input [0 ])
185
187
np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
186
188
utils .make_sure (shape [0 ] == 2 , "ComplexAbs expected the first dimension to be 2 but shape is %r" , shape )
0 commit comments