16
16
17
17
from tf2onnx import utils
18
18
from tf2onnx .handler import tf_op
19
+ from tf2onnx .utils import make_sure
19
20
20
21
logger = logging .getLogger (__name__ )
21
22
@@ -34,14 +35,13 @@ def version_10(cls, ctx, node, **kwargs):
34
35
narrow_range = node .get_attr ("narrow_range" ).i
35
36
num_bits = node .get_attr ("num_bits" ).i
36
37
37
- if narrow_range :
38
- raise RuntimeError (
39
- "Unable to convert node FakeQuantWithMinMaxArgs with "
40
- "narrow_range=%r" % narrow_range )
41
- if num_bits != 8 :
42
- raise RuntimeError (
43
- "Unable to convert node FakeQuantWithMinMaxArgs with "
44
- "num_bits=%r" % num_bits )
38
+ make_sure (
39
+ not narrow_range ,
40
+ "Unable to convert node FakeQuantWithMinMaxArgs with narrow_range=%r" ,
41
+ narrow_range )
42
+ make_sure (num_bits == 8 ,
43
+ "Unable to convert node FakeQuantWithMinMaxArgs with "
44
+ "num_bits=%r" , num_bits )
45
45
46
46
scale = (amax - amin ) / (2 ** num_bits - 1 )
47
47
min_adj = np .around (amin / scale )
@@ -55,12 +55,11 @@ def version_10(cls, ctx, node, **kwargs):
55
55
utils .make_name ("{}_scaley" .format (node .name )),
56
56
np .array (scale , dtype = np .float32 ))
57
57
zero = np .array (- min_adj , dtype = np .uint8 )
58
- if zero != - min_adj :
59
- raise RuntimeError (
60
- "Cannot convert FakeQuantWithMinMaxArgs with "
61
- "min={} max={} numbits={} because zero_scale={} "
62
- "is outside uint8 boundary" .format (
63
- amin , amax , num_bits , - min_adj ))
58
+ make_sure (zero == - min_adj ,
59
+ "Cannot convert FakeQuantWithMinMaxArgs with "
60
+ "min={} max={} numbits={} because zero_scale={} "
61
+ "is outside uint8 boundary" ,
62
+ amin , amax , num_bits , - min_adj )
64
63
zero_point = ctx .make_const (
65
64
utils .make_name ("{}_zpy" .format (node .name )), zero )
66
65
0 commit comments