Skip to content

Commit b18e633

Browse files
committed
Fix axis attribute not found errors in TF < 2.0
1 parent fa1f4b1 commit b18e633

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def create_qdq_nodes(g, match_results):
3333
min_quantized, max_quantized = [0, 255]
3434

3535
# Get axis attribute for per channel implementation.
36-
axis = qdq_node.attr['axis'].i
36+
if 'axis' in qdq_node.attr:
37+
axis = qdq_node.attr['axis'].i
3738

3839
# Get the min and max value of the inputs to QDQ op
3940
min_value = extract_numpy_array(qdq_node.inputs[1])
@@ -61,7 +62,10 @@ def create_qdq_nodes(g, match_results):
6162
if num_channels == 1:
6263
scales = scales[0]
6364
zero_point = zero_point[0]
64-
axis = np.int64(1) # Default value of axis
65+
attrs = {}
66+
else:
67+
utils.make_sure(axis, "Axis must be specified for per channel quantization")
68+
attrs = {'axis': axis}
6569

6670
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
6771
inverse_scale = (1/scales).astype(np.float32)
@@ -71,7 +75,7 @@ def create_qdq_nodes(g, match_results):
7175
inputs=[qdq_node.input[0], y_quant_scale.output[0],
7276
y_zero_point.output[0]],
7377
shapes=[qdq_node_output_shape],
74-
attr={'axis': axis},
78+
attr=attrs,
7579
dtypes=[qdq_node_output_dtype],
7680
name=utils.make_name("QuantLinearNode"))
7781

@@ -86,7 +90,7 @@ def create_qdq_nodes(g, match_results):
8690
y_inv_zero_point.output[0]],
8791
outputs=[qdq_node.output[0]],
8892
shapes=[qdq_node_output_shape],
89-
attr={'axis': axis},
93+
attr=attrs,
9094
dtypes=[qdq_node_output_dtype],
9195
name=utils.make_name("DequantLinearNode"))
9296
g.set_shape(dequant_node.output[0], qdq_node_output_shape)

0 commit comments

Comments
 (0)