Skip to content

Commit e912e11

Browse files
authored
Update RFFT2D to handle 3D arrays (Merge pull request #1649 from xadupre/tflfft)
Update RFFT2D to handle 3D arrays
2 parents 9e48a44 + e4a97d1 commit e912e11

File tree

6 files changed

+555
-172
lines changed

6 files changed

+555
-172
lines changed

tests/test_backend.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5392,7 +5392,6 @@ def func(input_val):
53925392
self.config.opset = current_opset
53935393

53945394
@check_tf_min_version("1.14")
5395-
#@skip_tflite("FlexRFFT2D")
53965395
def test_rfft_ops(self):
53975396

53985397
def dft_slow(x, M, fft_length):
@@ -5434,7 +5433,6 @@ def func3(x):
54345433
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
54355434

54365435
@check_tf_min_version("1.14")
5437-
#@skip_tflite("FlexRFFT2D")
54385436
@skip_tfjs("TFJS executes rfft with poor accuracy")
54395437
@check_opset_min_version(10, "Slice")
54405438
def test_rfft_ops_fft_length(self):
@@ -5446,7 +5444,6 @@ def func1_length(x):
54465444
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54475445

54485446
@check_tf_min_version("1.14")
5449-
#@skip_tflite("FlexRFFT2D")
54505447
@skip_tfjs("TFJS executes rfft with poor accuracy")
54515448
@check_opset_min_version(10, "Slice")
54525449
def test_rfft_ops_fft_length_many(self):
@@ -5461,7 +5458,6 @@ def func1_length(x):
54615458
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54625459

54635460
@check_tf_min_version("1.14")
5464-
#@skip_tflite("FlexRFFT2D")
54655461
@check_opset_min_version(10, "Slice")
54665462
def test_rfft_ops_fft_length_many_bigger(self):
54675463
for i in range(4, 7):
@@ -5491,8 +5487,7 @@ def func1_length(x):
54915487
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54925488

54935489
@check_tf_min_version("1.14")
5494-
#@skip_tflite("FlexRFFT2D")
5495-
@check_opset_min_version(10, "Slice")
5490+
@check_opset_min_version(11, "CumSum")
54965491
def test_rfft2d_ops(self):
54975492

54985493
x_val = make_xval([3, 4]).astype(np.float32)
@@ -5516,34 +5511,35 @@ def func3(x):
55165511
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
55175512

55185513
@check_tf_min_version("1.14")
5519-
#@skip_tflite("FlexRFFT2D")
5520-
@check_opset_min_version(10, "Slice")
5514+
@check_opset_min_version(11, "CumSum")
55215515
def test_rfft2d_ops_fft_length(self):
55225516

55235517
x_val = make_xval([3, 4]).astype(np.float32)
55245518
def func1_length(x):
55255519
op_ = tf.signal.rfft2d(x, np.array([3, 3], dtype=np.int32))
55265520
return tf.abs(op_, name=_TFOUTPUT)
5527-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5528-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5521+
with self.subTest(optimize=False):
5522+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5523+
with self.subTest(optimize=True):
5524+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
55295525

55305526
@check_tf_min_version("1.14")
5531-
#@skip_tflite("FlexRFFT2D")
5532-
@check_opset_min_version(10, "Slice")
5527+
@check_opset_min_version(11, "CumSum")
55335528
def test_rfft2d_ops_fft_length_many(self):
55345529
for i in range(7, 4, -1):
55355530
for j in range(7, 4, -1):
55365531
for m in range(0, 3):
55375532
for n in range(0, 3):
5538-
with self.subTest(shape=(i, j), fft_length=(m, n)):
5539-
x_val = make_xval([i, j]).astype(np.float32) / 100
5540-
def func1_length(x):
5541-
op_ = tf.signal.rfft2d(x, np.array([i-m, j-n], dtype=np.int32))
5542-
return tf.abs(op_, name=_TFOUTPUT)
5543-
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5533+
for opt in [False, True]:
5534+
with self.subTest(shape=(i, j), fft_length=(m, n), optimize=opt):
5535+
x_val = make_xval([i, j]).astype(np.float32) / 100
5536+
def func1_length(x):
5537+
op_ = tf.signal.rfft2d(x, np.array([i-m, j-n], dtype=np.int32))
5538+
return tf.abs(op_, name=_TFOUTPUT)
5539+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=opt)
55445540

55455541
@check_tf_min_version("1.14")
5546-
@check_opset_min_version(10, "Slice")
5542+
@check_opset_min_version(11, "CumSum")
55475543
@unittest.skipIf(True, reason="Not fully implemented for dynamic shape.")
55485544
def test_fft_ops(self):
55495545
x_val = make_xval([3, 4]).astype(np.float32)
@@ -5566,6 +5562,37 @@ def func(x):
55665562
x_val = np.array([1, 5, 2, 0, 3, 4], dtype=np.int64)
55675563
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
55685564

5565+
@check_tf_min_version("1.14")
5566+
@check_opset_min_version(11, "CumSum")
5567+
def test_rfft2d_ops_specific_dimension(self):
5568+
5569+
x_val = make_xval([3, 1, 4]).astype(np.float32)
5570+
5571+
def func1(x):
5572+
op_ = tf.signal.rfft2d(x, np.array([1, 4], dtype=np.int32))
5573+
return tf.abs(op_, name=_TFOUTPUT)
5574+
with self.subTest(shape=(3, 1, 4), fft_length=(1, 4), optimize=False):
5575+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5576+
with self.subTest(shape=(3, 1, 4), fft_length=(1, 4), optimize=True):
5577+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
5578+
5579+
for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:
5580+
for fft_length in [shape[-2:], (1, shape[-1]),
5581+
(min(2, shape[-2]), shape[-1]),
5582+
(shape[-2], 2),
5583+
(min(3, shape[-2]), min(4, shape[-2]))]:
5584+
if fft_length == (1, 1):
5585+
# The code fails in this case but that's unlikely to happen.
5586+
continue
5587+
for optimize in [False, True]:
5588+
with self.subTest(shape=shape, fft_length=fft_length, optimize=optimize):
5589+
x_val = make_xval(list(shape)).astype(np.float32)
5590+
x_val /= x_val.size
5591+
def func1(x):
5592+
op_ = tf.signal.rfft2d(x, np.array(fft_length, dtype=np.int32))
5593+
return tf.abs(op_, name=_TFOUTPUT)
5594+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=optimize)
5595+
55695596
@check_tf_min_version("2.1")
55705597
@skip_tflite("TFlite errors on some attributes")
55715598
@check_opset_min_version(9, "string")

tests/tfhub/_tools.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,15 @@ def call_tflite(inp):
403403
with open(onnx_name, "rb") as f:
404404
model_onnx = onnx.load(f)
405405

406-
call_tflite(imgs[0])
406+
interpreter_details = tf.lite.Interpreter(tname, experimental_preserve_all_tensors=True)
407+
input_details = interpreter_details.get_input_details()
408+
index_in = input_details[0]['index']
409+
interpreter_details.allocate_tensors()
410+
interpreter_details.set_tensor(index_in, imgs[0])
411+
interpreter_details.invoke()
412+
details = interpreter_details.get_tensor_details()
413+
407414
inputs = {input_name: imgs[0]}
408-
details = interpreter.get_tensor_details()
409415
names_index = {}
410416
for tt in details:
411417
names_index[tt['name']] = (tt['index'], tt['quantization'], tt['quantization_parameters'])
@@ -414,7 +420,7 @@ def call_tflite(inp):
414420
for name_tfl, name_ort in names:
415421
index = names_index[name_tfl]
416422

417-
tfl_value = interpreter.get_tensor(index[0])
423+
tfl_value = interpreter_details.get_tensor(index[0])
418424

419425
new_name = onnx_name + ".%s.onnx" % name_ort.replace(":", "_").replace(";", "_").replace("/", "_")
420426
if not os.path.exists(new_name):

0 commit comments

Comments
 (0)