24
24
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
25
25
26
26
27
- @tf_op ("FakeQuantWithMinMaxArgs" )
27
+ @tf_op ([ "FakeQuantWithMinMaxArgs" , "FakeQuantWithMinMaxVars" ] )
28
28
class FakeQuantWithMinMaxArgs :
29
29
# see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fake-quant-with-min-max-args
30
30
@classmethod
31
31
def version_10 (cls , ctx , node , ** kwargs ):
32
32
# hack to make up for the missing onnx pack op
33
- amin = node .get_attr ("min" ).f
34
- amax = node .get_attr ("max" ).f
33
+ if node .type == "FakeQuantWithMinMaxVars" :
34
+ utils .make_sure (node .inputs [1 ].is_scalar (), "%s node %s requires const scalar value for min" ,
35
+ node .type , node .name )
36
+ utils .make_sure (node .inputs [2 ].is_scalar (), "%s node %s requires const scalar value for max" ,
37
+ node .type , node .name )
38
+ amin = node .inputs [1 ].get_tensor_value ()
39
+ amax = node .inputs [2 ].get_tensor_value ()
40
+ else :
41
+ amin = node .get_attr ("min" ).f
42
+ amax = node .get_attr ("max" ).f
35
43
narrow_range = node .get_attr ("narrow_range" ).i
36
44
num_bits = node .get_attr ("num_bits" ).i
37
45
@@ -58,10 +66,10 @@ def version_10(cls, ctx, node, **kwargs):
58
66
zero = np .array (- min_adj , dtype = np .uint8 )
59
67
make_sure (
60
68
zero == - min_adj ,
61
- "Cannot convert FakeQuantWithMinMaxArgs with "
69
+ "Cannot convert %s node %s with "
62
70
"min=%r max=%r numbits=%r because zero_scale=%r "
63
71
"is outside uint8 boundary" ,
64
- amin , amax , num_bits , - min_adj )
72
+ node . type , node . name , amin , amax , num_bits , - min_adj )
65
73
zero_point = ctx .make_const (
66
74
utils .make_name ("{}_zpy" .format (node .name )), zero )
67
75
0 commit comments