@@ -3359,23 +3359,37 @@ def func(base_matrix, diag, k):
3359
3359
def test_fakequant_with_min_max (self ):
3360
3360
def func (x ):
3361
3361
ret = fake_quant_with_min_max_args (
3362
- x , min = - 1024 , max = 1024 , num_bits = 8 , narrow_range = False , name = None )
3362
+ x , min = - 1024 , max = 1023 , num_bits = 8 , narrow_range = False , name = None )
3363
3363
return tf .identity (ret , name = _TFOUTPUT )
3364
3364
3365
3365
x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024.
3366
3366
x_val0 = np .abs (x_val )
3367
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 })
3368
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3367
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 }, rtol = 1e-6 , atol = 1e-4 )
3368
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
3369
+
3370
+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024
3371
+ x_val [0 , 0 ] = - 1024
3372
+ x_val [0 , 1 ] = - 1023
3373
+ x_val [0 , 2 ] = 1024
3374
+ x_val [1 , 0 ] = 1023
3375
+ x_val [1 , 1 ] = 1025
3376
+ x_val [1 , 2 ] = - 1025
3377
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
3378
+
3379
+ @check_opset_min_version (10 )
3380
+ @check_tf_min_version ("1.14" )
3381
+ def test_fakequant_with_min_max_same_sign (self ):
3382
+ def func_neg (x ):
3383
+ ret = fake_quant_with_min_max_args (
3384
+ x , min = - 1024 * 3 , max = - 1024 , num_bits = 8 , narrow_range = False , name = None )
3385
+ return tf .identity (ret , name = _TFOUTPUT )
3386
+
3387
+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024 * 3.
3388
+ try :
3389
+ self ._run_test_case (func_neg , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
3390
+ except RuntimeError :
3391
+ pass
3369
3392
3370
3393
3371
3394
if __name__ == '__main__' :
3372
- #cl = BackendTests()
3373
- #cl.setUp()
3374
- #cl.test_fakequant_with_min_max()
3375
- #import cProfile
3376
- #cProfile.run('unittest_main()', 'restats')
3377
3395
unittest_main ()
3378
- #import pstats
3379
- #from pstats import SortKey
3380
- #p = pstats.Stats('restats')
3381
- #p.sort_stats(SortKey.CUMULATIVE).print_stats()
0 commit comments