9
9
from __future__ import unicode_literals
10
10
11
11
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,import-outside-toplevel
12
- # pylint: disable=wrong-import-position
12
+ # pylint: disable=wrong-import-position,invalid-unary-operand-type
13
13
14
14
import logging
15
15
import os
@@ -106,7 +106,8 @@ def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
106
106
raise ValueError ("unknown backend" )
107
107
return y
108
108
109
- def assert_results_equal (self , expected , actual , rtol , atol , check_value = True , check_shape = True , check_dtype = True ):
109
+ def assert_results_equal (self , expected , actual , rtol , atol , mtol = None ,
110
+ check_value = True , check_shape = True , check_dtype = True ):
110
111
for expected_val , actual_val in zip (expected , actual ):
111
112
if check_value :
112
113
if expected_val .dtype == np .object :
@@ -115,6 +116,11 @@ def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, c
115
116
expected_val_str = decode (expected_val )
116
117
self .assertAllEqual (expected_val_str , actual_val )
117
118
else :
119
+ if mtol is not None :
120
+ expected_val = np .minimum (expected_val , mtol )
121
+ expected_val = np .maximum (expected_val , - mtol )
122
+ actual_val = np .minimum (actual_val , mtol )
123
+ actual_val = np .maximum (actual_val , - mtol )
118
124
self .assertAllClose (expected_val , actual_val , rtol = rtol , atol = atol )
119
125
if check_dtype :
120
126
self .assertEqual (expected_val .dtype , actual_val .dtype )
@@ -285,10 +291,10 @@ def get_shape(info):
285
291
self .assertEqual (onnx_shape , tf2onnx_shape )
286
292
self .assertEqual (info .type .tensor_type .elem_type , graph .get_dtype (info .name ))
287
293
288
- def run_test_case (self , func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-07 , atol = 1e-5 ,
289
- convert_var_to_const = True , constant_fold = True , check_value = True , check_shape = True ,
290
- check_dtype = True , process_args = None , onnx_feed_dict = None , graph_validator = None , as_session = False ,
291
- large_model = False , premade_placeholders = False ):
294
+ def run_test_case (self , func , feed_dict , input_names_with_port , output_names_with_port ,
295
+ rtol = 1e-07 , atol = 1e-5 , mtol = None , convert_var_to_const = True , constant_fold = True ,
296
+ check_value = True , check_shape = True , check_dtype = True , process_args = None , onnx_feed_dict = None ,
297
+ graph_validator = None , as_session = False , large_model = False , premade_placeholders = False ):
292
298
test_tf = not self .config .skip_tf_tests
293
299
test_tflite = not self .config .skip_tflite_tests
294
300
run_tfl_consistency_test = test_tf and test_tflite and self .config .run_tfl_consistency_test
@@ -330,19 +336,19 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
330
336
g = optimizer .optimize_graph (g , catch_errors = False )
331
337
actual = self .run_backend (g , output_names_with_port , onnx_feed_dict , large_model )
332
338
333
- self .assert_results_equal (expected , actual , rtol , atol , check_value , check_shape , check_dtype )
339
+ self .assert_results_equal (expected , actual , rtol , atol , mtol , check_value , check_shape , check_dtype )
334
340
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
335
341
336
342
if graph_validator :
337
343
self .assertTrue (graph_validator (g ))
338
344
339
345
if test_tflite :
340
- tfl_results , tfl_outputs = self .run_tflite (tflite_path , feed_dict )
341
- test_tflite = tfl_results is not None
346
+ tfl_res , tfl_outputs = self .run_tflite (tflite_path , feed_dict )
347
+ test_tflite = tfl_res is not None
342
348
343
349
if test_tflite :
344
350
if run_tfl_consistency_test :
345
- self .assert_results_equal (expected , tfl_results , rtol , atol , check_value , check_shape , check_dtype )
351
+ self .assert_results_equal (expected , tfl_res , rtol , atol , mtol , check_value , check_shape , check_dtype )
346
352
347
353
tfl_process_args = process_args .copy ()
348
354
if 'inputs_as_nchw' in tfl_process_args :
@@ -358,9 +364,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
358
364
** tfl_process_args )
359
365
g = optimizer .optimize_graph (g )
360
366
onnx_feed_dict_without_port = {k .split (':' )[0 ]: v for k , v in onnx_feed_dict .items ()}
361
- onnx_from_tfl_res = self .run_backend (g , tfl_outputs , onnx_feed_dict_without_port , postfix = "_from_tflite" )
367
+ onnx_tfl_res = self .run_backend (g , tfl_outputs , onnx_feed_dict_without_port , postfix = "_from_tflite" )
362
368
363
- self .assert_results_equal (tfl_results , onnx_from_tfl_res , rtol , atol , check_value , check_shape , check_dtype )
369
+ self .assert_results_equal (tfl_res , onnx_tfl_res , rtol , atol , mtol , check_value , check_shape , check_dtype )
364
370
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
365
371
366
372
if graph_validator :
0 commit comments