
Commit 218b6b0

xadupre and sdpython authored

Fix missing argument when calling _get_quantize_input_nodes (#20245)

### Description

The current code calls one method with a missing argument.

### Motivation and Context

It breaks Olive's unit tests.

---------

Co-authored-by: Xavier Dupré <[email protected]>

1 parent a5182a2 commit 218b6b0
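
For context, the failure is the ordinary Python one: the dynamic-quantization helpers take a required `initial_type` argument, but the dispatching method shown in the diff below was still calling them without it. A simplified sketch of the pattern (names shortened; this is not the actual onnxruntime code):

```python
# Simplified sketch of the bug this commit fixes: the helper requires
# `initial_type`, but the dispatcher was still calling it without one.
def helper_int8(input_name, nodes_list, initial_type):
    return f"quantize {input_name!r} starting from type {initial_type}"

def dispatch_before_fix(input_name, nodes_list):
    # Raises: TypeError: helper_int8() missing 1 required positional
    # argument: 'initial_type'
    return helper_int8(input_name, nodes_list)

def dispatch_after_fix(input_name, nodes_list, initial_type):
    return helper_int8(input_name, nodes_list, initial_type)

print(dispatch_after_fix("input", [], 1))  # 1 == TensorProto.FLOAT
```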

File tree

3 files changed: +77 −9 lines changed

- onnxruntime/python/tools/quantization/onnx_quantizer.py
- onnxruntime/python/tools/quantization/operators/pad.py
- onnxruntime/test/python/quantization/test_op_gemm.py

onnxruntime/python/tools/quantization/onnx_quantizer.py

Lines changed: 30 additions & 8 deletions
```diff
@@ -306,20 +306,19 @@ def is_float_tensor(self, tensor_name):
             )
         return False
 
-    def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType):
+    def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType, initial_type):
         """
         Create nodes for dynamic quantization of input and add them to nodes_list.
             parameter input_name: Name of the input.
             parameter nodes_list: new nodes are appended to this list.
             parameter qType: type to quantize to.
+            parameter initial_type: type to quantize from
         return: scale_name, zero_point_name, scale_shape, zero_point_shape.
         """
         if qType == onnx_proto.TensorProto.INT8:
-            return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list)
+            return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list, initial_type)
         if qType == onnx_proto.TensorProto.UINT8:
-            return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list)
-        if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            return self._get_dynamic_input_quantization_params_float8e4m3fn(input_name, nodes_list)
+            return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list, initial_type)
         raise ValueError(f"Unexpected value for qType={qType}.")
 
     def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list, initial_type):
@@ -559,7 +558,9 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non
 
         return True, scale_name, zero_point_name, scale_shape, zero_point_shape
 
-    def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None):
+    def _get_quantize_input_nodes(
+        self, node, input_index, qType, given_scale_name=None, given_zp_name=None, initial_type=None
+    ):
         """
         Given an input for a node (which is not a initializer), this function
@@ -571,6 +572,7 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
         :param qType: type to quantize to.
         :param given_scale_name: if those inputs need to be quanitzed using this scale tensor.
         :param given_zp_name: if those inputs to be quantized using this zeropoint tensor.
+        :param initial_type: type of the weight to quantize
         :return: List of newly created nodes in NodeProto format.
         """
         input_name = node.input[input_index]
@@ -606,12 +608,16 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
                 ql_node_name,
             )
         else:
+            assert initial_type is not None, (
+                f"Cannot quantize input without knowing the initial type, "
+                f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}"
+            )
             (
                 scale_name,
                 zp_name,
                 scale_shape,
                 zp_shape,
-            ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType)
+            ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType, initial_type=initial_type)
             qlinear_node = onnx.helper.make_node(
                 "QuantizeLinear",
                 [input_name, scale_name, zp_name],
@@ -794,7 +800,23 @@ def __quantize_inputs(
                     node_input + "_QuantizeLinear", self.new_nodes, self.model.graph()
                 )
                 if qlinear_node is None:
-                    quantize_input_nodes = self._get_quantize_input_nodes(node, input_index, self.activation_qType)
+                    input_name = node.input[input_index]
+                    if input_name in self.value_infos:
+                        value_info = self.value_infos[input_name]
+                        assert value_info.HasField("type"), f"value_info={value_info} has no type."
+                        assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor."
+                        initial_type = value_info.type.tensor_type.elem_type
+                    else:
+                        # Shape inference failed. Fallback to self.tensor_names.
+                        assert input_name in self.tensor_names, (
+                            f"shape inference failed for {input_name!r} and "
+                            f"attribute 'tensor_names' does not have any value for "
+                            f"this tensor."
+                        )
+                        initial_type = self.tensor_names[input_name]
+                    quantize_input_nodes = self._get_quantize_input_nodes(
+                        node, input_index, self.activation_qType, initial_type=initial_type
+                    )
                 if quantize_input_nodes is None:
                     return (None, None, None, None)
                 if from_subgraph:
```
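
The last hunk above is where `initial_type` is recovered: the quantizer prefers the element type recorded in `value_infos` (populated by shape inference) and falls back to `tensor_names` when inference failed. A standalone sketch of that lookup using the public `onnx` API (the fallback dict here is a hypothetical stand-in for `self.tensor_names`):

```python
import onnx
from onnx import TensorProto, helper

# Tiny model so shape inference has an intermediate tensor to annotate.
model = helper.make_model(
    helper.make_graph(
        [helper.make_node("Relu", ["x"], ["y"]), helper.make_node("Relu", ["y"], ["z"])],
        "g",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
        [helper.make_tensor_value_info("z", TensorProto.FLOAT, [1])],
    )
)
model = onnx.shape_inference.infer_shapes(model)

# Collect element types the way the quantizer's preferred path does.
value_infos = {vi.name: vi for vi in model.graph.value_info}
value_infos.update({vi.name: vi for vi in model.graph.input})
value_infos.update({vi.name: vi for vi in model.graph.output})

fallback_types = {}  # hypothetical stand-in for self.tensor_names
name = "y"  # intermediate tensor, only typed after shape inference
if name in value_infos:
    vi = value_infos[name]
    assert vi.type.HasField("tensor_type"), f"{name!r} is not a tensor"
    initial_type = vi.type.tensor_type.elem_type
else:
    initial_type = fallback_types[name]  # raises if inference failed too
print(initial_type == TensorProto.FLOAT)  # True
```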

onnxruntime/python/tools/quantization/operators/pad.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -68,6 +68,7 @@ def quantize(self):
                 self.quantizer.activation_qType,
                 quantized_input_value.scale_name,
                 quantized_input_value.zp_name,
+                initial_type=scale_tensor.data_type,
             )
             self.quantizer.new_nodes.extend(pad_value_qnodes)
             node.input[2] = pad_value_qnodes[0].output[0]
```
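
For Pad, the value being quantized comes from an initializer, so its element type is already stored on the `TensorProto` itself; `scale_tensor.data_type` is that field. A quick illustration with the public `onnx` API:

```python
import numpy as np
import onnx

# An initializer's element type lives in TensorProto.data_type; this is
# what pad.py now forwards as `initial_type`.
scale = onnx.numpy_helper.from_array(np.array([0.1], dtype=np.float32), name="scale")
print(scale.data_type == onnx.TensorProto.FLOAT)  # True
```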

onnxruntime/test/python/quantization/test_op_gemm.py

Lines changed: 46 additions & 1 deletion
```diff
@@ -784,7 +784,52 @@ def test_qgemm_ref_uint8_specific_example(self):
         got = ref.run(None, feeds)[0]
         assert_allclose(expected, got)
 
+    def test_dynamic_quantization(self):
+        # dummy_model.onnx from Olive
+        model = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "Gemm", ["input", "fc1.weight", "fc1.bias"], ["gemm0"], alpha=1.0, beta=1.0, transB=1
+                    ),
+                    helper.make_node("Relu", ["gemm0"], ["output"]),
+                ],
+                "g",
+                [helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])],
+                [helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 10])],
+                [
+                    onnx.numpy_helper.from_array(np.random.randn(10, 1).astype(np.float32), name="fc1.weight"),
+                    onnx.numpy_helper.from_array(np.random.randn(10).astype(np.float32), name="fc1.bias"),
+                ],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+        onnx.checker.check_model(model)
+        run_config = {
+            "weight_type": QuantType.QInt8,
+            "op_types_to_quantize": None,
+            "nodes_to_quantize": None,
+            "nodes_to_exclude": None,
+            "per_channel": False,
+            "reduce_range": False,
+            "extra_options": {
+                "extra.Sigmoid.nnapi": False,
+                "ActivationSymmetric": False,
+                "WeightSymmetric": True,
+                "EnableSubgraph": False,
+                "ForceQuantizeNoInputCheck": False,
+                "MatMulConstBOnly": True,
+            },
+        }
+        model_path = "test_dynamic_quantization.onnx"
+        with open(model_path, "wb") as f:
+            f.write(model.SerializeToString())
+        qpath = "test_dynamic_quantization.quantized.onnx"
+        quantize_dynamic(model_input=model_path, model_output=qpath, use_external_data_format=True, **run_config)
+        onx = onnx.load(qpath)
+        self.assertIn("DynamicQuantizeLinear", set(n.op_type for n in onx.graph.node))
+
 
 if __name__ == "__main__":
-    TestOpGemm().test_quantize_gemm_e4m3fn_p3()
     unittest.main(verbosity=2)
```
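
The new test also serves as a small repro. Assuming onnxruntime's Python test dependencies are installed and the working directory is onnxruntime/test/python/quantization (so the module import resolves), it should run on its own; a sketch:

```python
# Hypothetical standalone runner for just the new test case.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_op_gemm.TestOpGemm.test_dynamic_quantization"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```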
