Skip to content

Commit 003abc7

Browse files
committed
pylint + fix comments
Signed-off-by: xavier dupré <[email protected]>
1 parent 8de3b19 commit 003abc7

File tree

3 files changed

+26
-28
lines changed

3 files changed

+26
-28
lines changed

tests/test_backend.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tf2onnx import constants, utils
2525
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2626
from tf2onnx.tf_loader import is_tf2
27-
from tf2onnx.onnx_opset.signal import DFT_constant
27+
from tf2onnx.onnx_opset.signal import make_dft_constant
2828

2929
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
3030

@@ -3633,18 +3633,17 @@ def DFT_slow(x, M):
36333633
return np.transpose(res, (0, 2, 1))
36343634

36353635
x_val = make_xval([2, 4]).astype(np.float32)
3636-
M_both = DFT_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
3636+
M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
36373637
fft = DFT_slow(x_val, M_both)
36383638
fft_npy = np.fft.rfft(x_val)
36393639
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
36403640
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
36413641

3642-
for op in [tf.signal.rfft]:
3643-
x_val = make_xval([3, 4]).astype(np.float32)
3644-
def func(x):
3645-
op_ = op(x)
3646-
return tf.abs(op_, name=_TFOUTPUT)
3647-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3642+
x_val = make_xval([3, 4]).astype(np.float32)
3643+
def func(x):
3644+
op_ = tf.signal.rfft(x)
3645+
return tf.abs(op_, name=_TFOUTPUT)
3646+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
36483647

36493648

36503649
if __name__ == '__main__':

tf2onnx/onnx_opset/signal.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,22 @@
1414
import numpy as np
1515
from onnx import onnx_pb
1616
from onnx.numpy_helper import to_array
17-
from tf2onnx import constants, utils
17+
from tf2onnx import utils
1818
from tf2onnx.handler import tf_op
19-
from tf2onnx.onnx_opset import common
2019

2120
logger = logging.getLogger(__name__)
2221

2322

2423
# pylint: disable=unused-argument,missing-docstring
2524

26-
def DFT_constant(N, dtype, fft_length):
25+
def make_dft_constant(N, dtype, fft_length):
2726
n = np.arange(N)
2827
k = n.reshape((N, 1)).astype(np.float64)
29-
M = np.exp(-2j * np.pi * k * n / N)
30-
M = M[:fft_length // 2 + 1]
31-
both = np.empty((2, ) + M.shape, dtype=dtype)
32-
both[0, :, :] = np.real(M)
33-
both[1, :, :] = np.imag(M)
28+
mat = np.exp(-2j * np.pi * k * n / N)
29+
mat = mat[:fft_length // 2 + 1]
30+
both = np.empty((2,) + mat.shape, dtype=dtype)
31+
both[0, :, :] = np.real(mat)
32+
both[1, :, :] = np.imag(mat)
3433
return both
3534

3635

@@ -52,7 +51,7 @@ def version_1(cls, ctx, node, **kwargs):
5251
<https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_.
5352
5453
Complex version:
55-
54+
5655
::
5756
5857
import numpy as np
@@ -71,7 +70,7 @@ def DFT(x, fft_length=None):
7170
if fft_length is None:
7271
fft_length = x.shape[0]
7372
cst = _DFT_cst(x.shape[0], fft_length)
74-
return np.dot(cst, x).T
73+
return np.dot(cst, x).T
7574
7675
Real version, first axis is (real, imag) part:
7776
@@ -84,7 +83,7 @@ def _DFT_real_cst(N, fft_length):
8483
k = n.reshape((N, 1)).astype(np.float64)
8584
M = np.exp(-2j * np.pi * k * n / N)
8685
M = M[:fft_length // 2 + 1]
87-
both = np.empty((2, ) + M.shape)
86+
both = np.empty((2,) + M.shape)
8887
both[0, :, :] = np.real(M)
8988
both[1, :, :] = np.imag(M)
9089
return both
@@ -104,7 +103,7 @@ def DFT_real(x, fft_length=None):
104103
onnx_dtype = ctx.get_dtype(node.input[0])
105104
shape = ctx.get_shape(node.input[0])
106105
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
107-
N = shape[-1]
106+
shape_n = shape[-1]
108107
utils.make_sure(len(node.input) == 2, "Two inputs expected not %r", len(node.input))
109108

110109
# This input should be a constant.
@@ -114,13 +113,13 @@ def DFT_real(x, fft_length=None):
114113
"fft_length should be a constant, the other case is not implemented yet.")
115114
value = node_fft_length.get_attr("value")
116115
value_array = to_array(value.t)
117-
utils.make_sure(value_array.shape == (1, ), "Unexpected shape for fft_length (%r)", value_array.shape)
116+
utils.make_sure(value_array.shape == (1,), "Unexpected shape for fft_length (%r)", value_array.shape)
118117
fft_length = value_array[0]
119118

120119
# TODO: handle this parameter when onnx.helper.make_node is fixed.
121120
# Tcomplex = node.get_attr("Tcomplex")
122121

123-
if np_dtype in (np.float16, ):
122+
if np_dtype == np.float16:
124123
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float16)
125124
np_dtype = np.float16
126125
elif np_dtype in (np.float32, np.complex64):
@@ -130,9 +129,9 @@ def DFT_real(x, fft_length=None):
130129
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float64)
131130
np_dtype = np.float64
132131

133-
real_imag_part = DFT_constant(N, np_dtype, fft_length)
132+
real_imag_part = make_dft_constant(shape_n, np_dtype, fft_length)
134133
onx_real_imag_part = ctx.make_const(
135-
name=utils.make_name('cst_rfft_%d' % N), np_val=real_imag_part)
134+
name=utils.make_name('cst_rfft_%d' % shape_n), np_val=real_imag_part)
136135

137136
shapei = list(np.arange(len(shape)))
138137
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
@@ -150,7 +149,7 @@ def DFT_real(x, fft_length=None):
150149
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
151150
last_node = ctx.make_node(
152151
"Transpose", inputs=[mult.output[0]], attr=dict(perm=perm),
153-
name=utils.make_name('CPLX_' + node.name + 'rfft'),
152+
name=utils.make_name('CPLX_' + node.name + 'rfft'),
154153
shapes=[new_shape], dtypes=[res_onnx_dtype])
155154

156155
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
@@ -225,5 +224,5 @@ def version_1(cls, ctx, node, **kwargs):
225224
cls.any_version(1, ctx, node, **kwargs)
226225

227226
@classmethod
228-
def version_11(cls, ctx, node, **kwargs):
227+
def version_13(cls, ctx, node, **kwargs):
229228
cls.any_version(11, ctx, node, **kwargs)

tf2onnx/tf_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
236236
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
237237
"Toutput_types",
238238
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
239-
# onnx.helper.make_node fails,
240-
# TODO: it should be added back.
239+
# onnx.helper.make_node fails,
240+
# TODO: it should be added back.
241241
}
242242

243243
node_list = g.get_operations()

0 commit comments

Comments
 (0)