Skip to content

Commit 720cfb7

Browse files
committed
lint
Signed-off-by: xavier dupré <[email protected]>
1 parent a47f590 commit 720cfb7

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

tests/test_backend.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5493,8 +5493,10 @@ def test_rfft2d_ops_fft_length(self):
54935493
def func1_length(x):
54945494
op_ = tf.signal.rfft2d(x, np.array([3, 3], dtype=np.int32))
54955495
return tf.abs(op_, name=_TFOUTPUT)
5496-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5497-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5496+
with self.subTest(optimize=False):
5497+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5498+
with self.subTest(optimize=True):
5499+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54985500

54995501
@check_tf_min_version("1.14")
55005502
@check_opset_min_version(10, "Slice")
@@ -5503,12 +5505,13 @@ def test_rfft2d_ops_fft_length_many(self):
55035505
for j in range(7, 4, -1):
55045506
for m in range(0, 3):
55055507
for n in range(0, 3):
5506-
with self.subTest(shape=(i, j), fft_length=(m, n)):
5507-
x_val = make_xval([i, j]).astype(np.float32) / 100
5508-
def func1_length(x):
5509-
op_ = tf.signal.rfft2d(x, np.array([i-m, j-n], dtype=np.int32))
5510-
return tf.abs(op_, name=_TFOUTPUT)
5511-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5508+
for opt in [False, True]:
5509+
with self.subTest(shape=(i, j), fft_length=(m, n), optimize=opt):
5510+
x_val = make_xval([i, j]).astype(np.float32) / 100
5511+
def func1_length(x):
5512+
op_ = tf.signal.rfft2d(x, np.array([i-m, j-n], dtype=np.int32))
5513+
return tf.abs(op_, name=_TFOUTPUT)
5514+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=opt)
55125515

55135516
@check_tf_min_version("1.14")
55145517
@check_opset_min_version(10, "Slice")
@@ -5543,21 +5546,23 @@ def test_rfft2d_ops_specific_dimension(self):
55435546
def func1(x):
55445547
op_ = tf.signal.rfft2d(x, np.array([1, 4], dtype=np.int32))
55455548
return tf.abs(op_, name=_TFOUTPUT)
5546-
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5547-
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
5549+
with self.subTest(shape=(3, 1, 4), fft_length=(1, 4), optimize=False):
5550+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5551+
with self.subTest(shape=(3, 1, 4), fft_length=(1, 4), optimize=True):
5552+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
55485553

55495554
for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:
55505555
for fft_length in [shape[-2:], (1, shape[-1]),
55515556
(min(2, shape[-2]), shape[-1]),
55525557
(shape[-2], 2),
55535558
(min(3, shape[-2]), min(4, shape[-2]))]:
5554-
with self.subTests(shape=shape, fft_length=fft_length):
5555-
x_val = make_xval(list(shape)).astype(np.float32)
5556-
def func1(x):
5557-
op_ = tf.signal.rfft2d(x, np.array(fft_length, dtype=np.int32))
5558-
return tf.abs(op_, name=_TFOUTPUT)
5559-
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5560-
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
5559+
for optimize in [False, True]:
5560+
with self.subTest(shape=shape, fft_length=fft_length, optimize=optimize):
5561+
x_val = make_xval(list(shape)).astype(np.float32)
5562+
def func1(x):
5563+
op_ = tf.signal.rfft2d(x, np.array(fft_length, dtype=np.int32))
5564+
return tf.abs(op_, name=_TFOUTPUT)
5565+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=optimize)
55615566

55625567
@check_tf_min_version("2.1")
55635568
@skip_tflite("TFlite errors on some attributes")

tf2onnx/onnx_opset/signal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def onnx_rfft_2d_any_test(x, fft_length):
466466
consumer_types == {'ComplexAbs'},
467467
"Current implementation of RFFT2D only allows ComplexAbs as consumer not %r",
468468
consumer_types)
469-
469+
470470
oldnode = node
471471
fft_length = node.input[1]
472472
onnx_dtype = ctx.get_dtype(node.input[0])
@@ -475,7 +475,7 @@ def onnx_rfft_2d_any_test(x, fft_length):
475475
"Unsupported input type.")
476476

477477
fft_length_node = ctx.make_node(
478-
'Cast', inputs=[node.input[1]], attr={'to': onnx_pb.TensorProto.INT64},
478+
'Cast', inputs=[fft_length], attr={'to': onnx_pb.TensorProto.INT64},
479479
name=utils.make_name('fft_length_cast'))
480480
varx = {"x": node.input[0], "fft_length": fft_length_node.output[0]}
481481

@@ -492,7 +492,7 @@ def onnx_rfft_2d_any_test(x, fft_length):
492492
value = np.array([1], dtype=np.int64)
493493
varx['Co_Concatcst'] = ctx.make_const(name=utils.make_name('init_Co_Concatcst'), np_val=value).name
494494

495-
value = np.array([-6.2831854820251465], dtype=np.float32)
495+
value = np.array([-6.2831854820251465], dtype=np_dtype)
496496
varx['Id_Identitycst'] = ctx.make_const(name=utils.make_name('init_Id_Identitycst'), np_val=value).name
497497

498498
value = np.array([2], dtype=np.int64)
@@ -931,7 +931,7 @@ def onnx_rfft_2d_any_test(x, fft_length):
931931

932932
inputs = [varx['Sl_output0'], varx['Co_concat_result012']]
933933
node = ctx.make_node('Reshape', inputs=inputs, name=utils.make_name('Re_Reshape20'))
934-
varx['y'] = node.output[0]
934+
varx['y'] = node.output[0]
935935

936936
# finalize
937937
if getattr(ctx, 'verbose', False):

0 commit comments

Comments
 (0)