Skip to content

Commit 93d9d10

Browse files
committed
add unit test
Signed-off-by: xavier dupré <[email protected]>
1 parent 1aec83c commit 93d9d10

File tree

3 files changed

+707
-642
lines changed

3 files changed

+707
-642
lines changed

tests/backend_test_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
8989
# opt.log_severity_level = 0
9090
# opt.log_verbosity_level = 255
9191
# opt.enable_profiling = True
92+
import onnx
93+
with open(model_path, "rb") as f:
94+
onx = onnx.load(f)
95+
print(onx)
96+
from mlprodict.onnxrt import OnnxInference
97+
oinf = OnnxInference(model_path)
98+
oinf.run(inputs, verbose=1, fLOG=print)
9299

93100
m = rt.InferenceSession(model_path, opt, providers=providers)
94101
results = m.run(output_names, inputs)

tests/test_backend.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5367,7 +5367,6 @@ def func(input_val):
53675367
self.config.opset = current_opset
53685368

53695369
@check_tf_min_version("1.14")
5370-
#@skip_tflite("FlexRFFT2D")
53715370
def test_rfft_ops(self):
53725371

53735372
def dft_slow(x, M, fft_length):
@@ -5409,7 +5408,6 @@ def func3(x):
54095408
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
54105409

54115410
@check_tf_min_version("1.14")
5412-
#@skip_tflite("FlexRFFT2D")
54135411
@skip_tfjs("TFJS executes rfft with poor accuracy")
54145412
@check_opset_min_version(10, "Slice")
54155413
def test_rfft_ops_fft_length(self):
@@ -5421,7 +5419,6 @@ def func1_length(x):
54215419
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54225420

54235421
@check_tf_min_version("1.14")
5424-
#@skip_tflite("FlexRFFT2D")
54255422
@skip_tfjs("TFJS executes rfft with poor accuracy")
54265423
@check_opset_min_version(10, "Slice")
54275424
def test_rfft_ops_fft_length_many(self):
@@ -5436,7 +5433,6 @@ def func1_length(x):
54365433
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54375434

54385435
@check_tf_min_version("1.14")
5439-
#@skip_tflite("FlexRFFT2D")
54405436
@check_opset_min_version(10, "Slice")
54415437
def test_rfft_ops_fft_length_many_bigger(self):
54425438
for i in range(4, 7):
@@ -5466,7 +5462,6 @@ def func1_length(x):
54665462
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
54675463

54685464
@check_tf_min_version("1.14")
5469-
#@skip_tflite("FlexRFFT2D")
54705465
@check_opset_min_version(10, "Slice")
54715466
def test_rfft2d_ops(self):
54725467

@@ -5491,7 +5486,6 @@ def func3(x):
54915486
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
54925487

54935488
@check_tf_min_version("1.14")
5494-
#@skip_tflite("FlexRFFT2D")
54955489
@check_opset_min_version(10, "Slice")
54965490
def test_rfft2d_ops_fft_length(self):
54975491

@@ -5503,7 +5497,6 @@ def func1_length(x):
55035497
self._run_test_case(func1_length, [_OUTPUT], {_INPUT: x_val})
55045498

55055499
@check_tf_min_version("1.14")
5506-
#@skip_tflite("FlexRFFT2D")
55075500
@check_opset_min_version(10, "Slice")
55085501
def test_rfft2d_ops_fft_length_many(self):
55095502
for i in range(7, 4, -1):
@@ -5541,7 +5534,34 @@ def func(x):
55415534
x_val = np.array([1, 5, 2, 0, 3, 4], dtype=np.int64)
55425535
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
55435536

5537+
@check_tf_min_version("1.14")
5538+
@check_opset_min_version(10, "Slice")
5539+
def test_rfft2d_ops_specific_dimension(self):
5540+
5541+
x_val = make_xval([3, 1, 4]).astype(np.float32)
5542+
5543+
def func1(x):
5544+
op_ = tf.signal.rfft2d(x, np.array([1, 4], dtype=np.int32))
5545+
return tf.abs(op_, name=_TFOUTPUT)
5546+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5547+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
5548+
5549+
for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:
5550+
for fft_length in [shape[-2:], (1, shape[-1]),
5551+
(min(2, shape[-2]), shape[-1]),
5552+
(shape[-2], 2),
5553+
(min(3, shape[-2]), min(4, shape[-2]))]:
5554+
with self.subTests(shape=shape, fft_length=fft_length):
5555+
x_val = make_xval(list(shape)).astype(np.float32)
5556+
def func1(x):
5557+
op_ = tf.signal.rfft2d(x, np.array(fft_length, dtype=np.int32))
5558+
return tf.abs(op_, name=_TFOUTPUT)
5559+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
5560+
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
55445561

55455562

55465563
if __name__ == '__main__':
5564+
cl = BackendTests()
5565+
cl.setUp()
5566+
cl.test_rfft2d_ops_specific_dimension()
55475567
unittest_main()

0 commit comments

Comments
 (0)