Skip to content

Commit b53f1bb

Browse files
authored
Merge pull request #1124 from xadupre/sp
Moves unused class attribute to function local variable.
2 parents ba533d0 + a7656f4 commit b53f1bb

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -595,11 +595,6 @@ def version_10(cls, ctx, node, **kwargs):
595595
@tf_op("Atan2")
596596
class Atan2Op:
597597
# support more dtype
598-
supported_dtypes = [
599-
onnx_pb.TensorProto.FLOAT,
600-
onnx_pb.TensorProto.FLOAT16,
601-
onnx_pb.TensorProto.DOUBLE
602-
]
603598

604599
@classmethod
605600
def version_9(cls, ctx, node, **kwargs):
@@ -615,8 +610,14 @@ def atan2(y, x):
615610
atan_part = numpy.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
616611
return atan_part + pi_part
617612
"""
613+
supported_dtypes = [
614+
onnx_pb.TensorProto.FLOAT,
615+
onnx_pb.TensorProto.FLOAT16,
616+
onnx_pb.TensorProto.DOUBLE
617+
]
618618

619619
onnx_dtype = ctx.get_dtype(node.input[0])
620+
utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
620621
shape = ctx.get_shape(node.input[0])
621622
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
622623

tf2onnx/onnx_opset/signal.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,6 @@ def make_dft_constant(length, dtype, fft_length):
3636
@tf_op("RFFT")
3737
class RFFTOp:
3838
# support more dtype
39-
supported_dtypes = [
40-
onnx_pb.TensorProto.FLOAT,
41-
onnx_pb.TensorProto.FLOAT16,
42-
onnx_pb.TensorProto.DOUBLE,
43-
onnx_pb.TensorProto.COMPLEX64,
44-
onnx_pb.TensorProto.COMPLEX128,
45-
]
4639

4740
@classmethod
4841
def version_1(cls, ctx, node, **kwargs):
@@ -99,6 +92,13 @@ def DFT_real(x, fft_length=None):
9992
res = np.dot(cst, x)
10093
return np.transpose(res, (0, 2, 1))
10194
"""
95+
supported_dtypes = [
96+
onnx_pb.TensorProto.FLOAT,
97+
onnx_pb.TensorProto.FLOAT16,
98+
onnx_pb.TensorProto.DOUBLE,
99+
onnx_pb.TensorProto.COMPLEX64,
100+
onnx_pb.TensorProto.COMPLEX128,
101+
]
102102
consumers = ctx.find_output_consumers(node.output[0])
103103
consumer_types = set(op.type for op in consumers)
104104
utils.make_sure(
@@ -107,6 +107,7 @@ def DFT_real(x, fft_length=None):
107107
consumer_types)
108108

109109
onnx_dtype = ctx.get_dtype(node.input[0])
110+
utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
110111
shape = ctx.get_shape(node.input[0])
111112
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
112113
shape_n = shape[-1]
@@ -164,13 +165,6 @@ def DFT_real(x, fft_length=None):
164165
@tf_op("ComplexAbs")
165166
class ComplexAbsOp:
166167
# support more dtype
167-
supported_dtypes = [
168-
onnx_pb.TensorProto.FLOAT,
169-
onnx_pb.TensorProto.FLOAT16,
170-
onnx_pb.TensorProto.DOUBLE,
171-
onnx_pb.TensorProto.COMPLEX64,
172-
onnx_pb.TensorProto.COMPLEX128,
173-
]
174168

175169
@classmethod
176170
def any_version(cls, opset, ctx, node, **kwargs):
@@ -180,7 +174,15 @@ def any_version(cls, opset, ctx, node, **kwargs):
180174
it assumes the first dimension means real part (0)
181175
and imaginary part (1, :, :...).
182176
"""
177+
supported_dtypes = [
178+
onnx_pb.TensorProto.FLOAT,
179+
onnx_pb.TensorProto.FLOAT16,
180+
onnx_pb.TensorProto.DOUBLE,
181+
onnx_pb.TensorProto.COMPLEX64,
182+
onnx_pb.TensorProto.COMPLEX128,
183+
]
183184
onnx_dtype = ctx.get_dtype(node.input[0])
185+
utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
184186
shape = ctx.get_shape(node.input[0])
185187
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
186188
utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)

0 commit comments

Comments
 (0)