@@ -33,7 +33,8 @@ def create_qdq_nodes(g, match_results):
33
33
min_quantized , max_quantized = [0 , 255 ]
34
34
35
35
# Get axis attribute for per channel implementation.
36
- axis = qdq_node .attr ['axis' ].i
36
+ if 'axis' in qdq_node .attr :
37
+ axis = qdq_node .attr ['axis' ].i
37
38
38
39
# Get the min and max value of the inputs to QDQ op
39
40
min_value = extract_numpy_array (qdq_node .inputs [1 ])
@@ -61,7 +62,10 @@ def create_qdq_nodes(g, match_results):
61
62
if num_channels == 1 :
62
63
scales = scales [0 ]
63
64
zero_point = zero_point [0 ]
64
- axis = np .int64 (1 ) # Default value of axis
65
+ attrs = {}
66
+ else :
67
+ utils .make_sure (axis , "Axis must be specified for per channel quantization" )
68
+ attrs = {'axis' : axis }
65
69
66
70
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
67
71
inverse_scale = (1 / scales ).astype (np .float32 )
@@ -71,7 +75,7 @@ def create_qdq_nodes(g, match_results):
71
75
inputs = [qdq_node .input [0 ], y_quant_scale .output [0 ],
72
76
y_zero_point .output [0 ]],
73
77
shapes = [qdq_node_output_shape ],
74
- attr = { 'axis' : axis } ,
78
+ attr = attrs ,
75
79
dtypes = [qdq_node_output_dtype ],
76
80
name = utils .make_name ("QuantLinearNode" ))
77
81
@@ -86,7 +90,7 @@ def create_qdq_nodes(g, match_results):
86
90
y_inv_zero_point .output [0 ]],
87
91
outputs = [qdq_node .output [0 ]],
88
92
shapes = [qdq_node_output_shape ],
89
- attr = { 'axis' : axis } ,
93
+ attr = attrs ,
90
94
dtypes = [qdq_node_output_dtype ],
91
95
name = utils .make_name ("DequantLinearNode" ))
92
96
g .set_shape (dequant_node .output [0 ], qdq_node_output_shape )
0 commit comments