Skip to content

Commit a8d5a3c

Browse files
RFFT2D for TFLite (#1621)
* add test for fft2d Signed-off-by: xavier dupré <[email protected]> * fix rfft implements fft2s Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * remove one unit test not supported by tensorflow Signed-off-by: xavier dupré <[email protected]> * fix unit test Signed-off-by: xavier dupré <[email protected]> * fix missing variable Signed-off-by: xavier dupré <[email protected]> * remove one test Signed-off-by: xavier dupré <[email protected]> * fix fft 1d, fft2d Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * fix fft2d Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * adjust min opset Signed-off-by: xavier dupré <[email protected]> * handle the case where fft_length > shape[-1] Signed-off-by: xavier dupré <[email protected]> * add reshape Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * fix fft2d Signed-off-by: xavier dupré <[email protected]> * add one more comment Signed-off-by: xavier dupré <[email protected]> * Add tflite support for rfft2d Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent 8c081ef commit a8d5a3c

File tree

9 files changed

+634
-106
lines changed

9 files changed

+634
-106
lines changed

tests/backend_test_base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,12 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
8383
if use_custom_ops:
8484
from onnxruntime_extensions import get_library_path
8585
opt.register_custom_ops_library(get_library_path())
86+
8687
# in case of issues with the runtime, one can enable more logging
8788
# opt.log_severity_level = 0
8889
# opt.log_verbosity_level = 255
8990
# opt.enable_profiling = True
91+
9092
m = rt.InferenceSession(model_path, opt, providers=providers)
9193
results = m.run(output_names, inputs)
9294
return results
@@ -316,7 +318,15 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
316318
rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
317319
check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
318320
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
319-
use_custom_ops=False):
321+
use_custom_ops=False, optimize=True):
322+
"""
323+
This function tests all scenarios available through the command line.
324+
The command line always runs the optimizers.
325+
However, they may modify the final graph into something different than the
326+
tested converter implements. Set `optimize=False` to keep the original
327+
set of nodes and helps debugging. However, the same function should
328+
be called with `optimize=True` to test what the user would actually get.
329+
"""
320330
test_tf = not self.config.skip_tf_tests
321331
test_tflite = not self.config.skip_tflite_tests
322332
test_tfjs = not self.config.skip_tfjs_tests
@@ -365,7 +375,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
365375
const_node_values=const_node_values,
366376
initialized_tables=initialized_tables,
367377
**process_args)
368-
g = optimizer.optimize_graph(g, catch_errors=False)
378+
if optimize:
379+
g = optimizer.optimize_graph(g, catch_errors=False)
369380
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
370381
use_custom_ops=use_custom_ops)
371382

@@ -395,7 +406,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
395406
target=self.config.target,
396407
tflite_path=tflite_path,
397408
**tfl_process_args)
398-
g = optimizer.optimize_graph(g)
409+
if optimize:
410+
g = optimizer.optimize_graph(g)
399411
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
400412
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
401413
postfix="_from_tflite", use_custom_ops=use_custom_ops)

tests/test_backend.py

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5367,21 +5367,29 @@ def func(input_val):
53675367
self.config.opset = current_opset
53685368

53695369
@check_tf_min_version("1.14")
5370-
@skip_tflite("FlexRFFT2D")
5370+
#@skip_tflite("FlexRFFT2D")
53715371
def test_rfft_ops(self):
53725372

5373-
def dft_slow(x, M):
5374-
xt = x.T
5375-
res = np.dot(M, xt)
5373+
def dft_slow(x, M, fft_length):
5374+
xt = x[:, :fft_length].T
5375+
size = fft_length // 2 + 1
5376+
res = np.dot(M[:, :, :fft_length], xt)[:, :size, :]
53765377
return np.transpose(res, (0, 2, 1))
53775378

53785379
x_val = make_xval([2, 4]).astype(np.float32)
53795380
M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
5380-
fft = dft_slow(x_val, M_both)
5381+
fft = dft_slow(x_val, M_both, x_val.shape[1])
53815382
fft_npy = np.fft.rfft(x_val)
53825383
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
53835384
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
53845385

5386+
x_val = make_xval([2, 4]).astype(np.float32)
5387+
M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1]-1)
5388+
fft = dft_slow(x_val, M_both, x_val.shape[1]-1)
5389+
fft_npy = np.fft.rfft(x_val, x_val.shape[1]-1)
5390+
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
5391+
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
5392+
53855393
x_val = make_xval([3, 4]).astype(np.float32)
53865394
def func1(x):
53875395
op_ = tf.signal.rfft(x)
@@ -5401,7 +5409,117 @@ def func3(x):
54015409
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
54025410

54035411
@check_tf_min_version("1.14")
5404-
@check_opset_min_version(11, "range")
5412+
#@skip_tflite("FlexRFFT2D")
5413+
@skip_tfjs("TFJS executes rfft with poor accuracy")
5414+
@check_opset_min_version(10, "Slice")
5415+
def test_rfft_ops_fft_length(self):
5416+
5417+
x_val = make_xval([3, 9]).astype(np.float32)
5418+
def func1_length(x):
5419+
op_ = tf.signal.rfft(x, np.array([8], dtype=np.int32))
5420+
return tf.abs(op_, name=_TFOUTPUT)
5421+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5422+
5423+
@check_tf_min_version("1.14")
5424+
#@skip_tflite("FlexRFFT2D")
5425+
@skip_tfjs("TFJS executes rfft with poor accuracy")
5426+
@check_opset_min_version(10, "Slice")
5427+
def test_rfft_ops_fft_length_many(self):
5428+
for i in range(4, 7):
5429+
for j in range(4, 7):
5430+
for m in range(0, 3):
5431+
with self.subTest(shape=(i, j), fft_length=j-m):
5432+
x_val = make_xval([i, j]).astype(np.float32)
5433+
def func1_length(x):
5434+
op_ = tf.signal.rfft(x, np.array([j-m], dtype=np.int32))
5435+
return tf.abs(op_, name=_TFOUTPUT)
5436+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5437+
5438+
@check_tf_min_version("1.14")
5439+
#@skip_tflite("FlexRFFT2D")
5440+
@check_opset_min_version(10, "Slice")
5441+
def test_rfft_ops_fft_length_many_bigger(self):
5442+
for i in range(4, 7):
5443+
for j in range(4, 7):
5444+
for m in range(0, 3):
5445+
with self.subTest(shape=(i, j), fft_length=j+m):
5446+
x_val = make_xval([i, j]).astype(np.float32) / 10
5447+
def func1_length(x):
5448+
op_ = tf.signal.rfft(x, np.array([j+m], dtype=np.int32))
5449+
return tf.abs(op_, name=_TFOUTPUT)
5450+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5451+
5452+
@check_tf_min_version("1.14")
5453+
@skip_tflite("Slight accuracy issues with some shapes")
5454+
@skip_tfjs("TFJS executes rfft with poor accuracy")
5455+
@check_opset_min_version(10, "Slice")
5456+
def test_rfft_ops_fft_length_many_larger(self):
5457+
for i in range(4, 7):
5458+
for j in range(4, 7):
5459+
for m in range(-3, 3):
5460+
with self.subTest(shape=(3, i, j), fft_length=j+m):
5461+
x_val = make_xval([3, i, j]).astype(np.float32) / 10
5462+
def func1_length(x):
5463+
op_ = tf.signal.rfft(x, np.array([j+m], dtype=np.int32))
5464+
return tf.abs(op_, name=_TFOUTPUT)
5465+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5466+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5467+
5468+
@check_tf_min_version("1.14")
5469+
#@skip_tflite("FlexRFFT2D")
5470+
@check_opset_min_version(10, "Slice")
5471+
def test_rfft2d_ops(self):
5472+
5473+
x_val = make_xval([3, 4]).astype(np.float32)
5474+
5475+
def func1(x):
5476+
op_ = tf.signal.rfft2d(x)
5477+
return tf.abs(op_, name=_TFOUTPUT)
5478+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5479+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
5480+
5481+
def func2(x):
5482+
op_ = tf.signal.rfft2d(x)
5483+
return tf.cos(op_, name=_TFOUTPUT)
5484+
with self.assertRaises(ValueError):
5485+
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})
5486+
5487+
def func3(x):
5488+
op_ = tf.signal.rfft2d(x)
5489+
return tf.identity(op_, name=_TFOUTPUT)
5490+
with self.assertRaises(ValueError):
5491+
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
5492+
5493+
@check_tf_min_version("1.14")
5494+
#@skip_tflite("FlexRFFT2D")
5495+
@check_opset_min_version(10, "Slice")
5496+
def test_rfft2d_ops_fft_length(self):
5497+
5498+
x_val = make_xval([3, 4]).astype(np.float32)
5499+
def func1_length(x):
5500+
op_ = tf.signal.rfft2d(x, np.array([3, 3], dtype=np.int32))
5501+
return tf.abs(op_, name=_TFOUTPUT)
5502+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5503+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5504+
5505+
@check_tf_min_version("1.14")
5506+
#@skip_tflite("FlexRFFT2D")
5507+
@check_opset_min_version(10, "Slice")
5508+
def test_rfft2d_ops_fft_length_many(self):
5509+
for i in range(7, 4, -1):
5510+
for j in range(7, 4, -1):
5511+
for m in range(0, 3):
5512+
for n in range(0, 3):
5513+
with self.subTest(shape=(i, j), fft_length=(m, n)):
5514+
x_val = make_xval([i, j]).astype(np.float32) / 100
5515+
def func1_length(x):
5516+
op_ = tf.signal.rfft2d(x, np.array([i-m, j-n], dtype=np.int32))
5517+
return tf.abs(op_, name=_TFOUTPUT)
5518+
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
5519+
5520+
@check_tf_min_version("1.14")
5521+
@check_opset_min_version(10, "Slice")
5522+
@unittest.skipIf(True, reason="Not fully implemented for dynamic shape.")
54055523
def test_fft_ops(self):
54065524
x_val = make_xval([3, 4]).astype(np.float32)
54075525
def func1(x):

tests/tfhub/_tools.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def check_discrepencies(out1, out2, threshold=1e-3):
196196

197197

198198
def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
199-
signature=None, tag=None, output_name=None, ort_name=None):
199+
signature=None, tag=None, output_name=None, ort_name=None,
200+
optimize=True):
200201
"""
201202
Runs a simple benchmark.
202203
Goes through every steps (download, convert).
@@ -225,7 +226,12 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
225226
onnx_name = onnx_name_unzipped
226227

227228
# Benchmarks both models.
228-
ort = onnxruntime.InferenceSession(onnx_name)
229+
if optimize:
230+
ort = onnxruntime.InferenceSession(onnx_name)
231+
else:
232+
so = onnxruntime.SessionOptions()
233+
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
234+
ort = onnxruntime.InferenceSession(onnx_name)
229235

230236
if verbose:
231237
print("ONNX inputs:")

tests/tfhub/tfhub_humpback_whale.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ def main(opset=13):
1010
name = "humpback-whale"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

13-
imgs = generate_random_images(shape=(1, 331, 331, 3))
13+
imgs = generate_random_images(shape=(1, 1024, 1))
14+
inputs = [dict(waveform=img,
15+
context_step_samples=numpy.array(512, dtype=numpy.int64))
16+
for img in imgs]
1417

15-
benchmark(url, dest, onnx_name, opset, imgs)
18+
benchmark(url, dest, onnx_name, opset, inputs, optimize=False)
1619

1720

1821
if __name__ == "__main__":

0 commit comments

Comments
 (0)