
Commit 3d2004b

Enable direct INT8/FP8 input in ONNX graph (#354)
Signed-off-by: gcunhase <[email protected]>
1 parent add61db commit 3d2004b

3 files changed: 82 additions, 3 deletions

modelopt/onnx/quantization/__main__.py

Lines changed: 3 additions & 2 deletions

@@ -36,7 +36,7 @@ def get_parser() -> argparse.ArgumentParser:
         type=str,
         choices=["fp8", "int8", "int4"],
         default="int8",
-        help=("Quantization mode for the given ONNX model."),
+        help="Quantization mode for the given ONNX model.",
     )
     argparser.add_argument(
         "--calibration_method",
@@ -246,7 +246,8 @@ def get_parser() -> argparse.ArgumentParser:
         action="store_true",
         help=(
            "If True, the I/O types in the quantized ONNX model will be modified to be lower precision whenever "
-            "possible. Else, they will match the I/O types in the given ONNX model."
+            "possible. Else, they will match the I/O types in the given ONNX model. "
+            "The currently supported precisions are {fp16, int8, fp8}."
         ),
     )
     return argparser
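
For reference, the updated help text corresponds to an invocation along these lines. This is a sketch, not the commit's documented usage: the flag name --direct_io_types is inferred from the direct_io_types parameter wired up in quantize.py below, and --onnx_path/--output_path are assumed to match the module's other CLI options. Running the package with python -m is grounded in the __main__.py file path.

    python -m modelopt.onnx.quantization \
        --onnx_path=model.onnx \
        --quantize_mode=int8 \
        --direct_io_types \
        --output_path=model.quant.onnx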

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 72 additions & 0 deletions

@@ -871,6 +871,78 @@ def remove_input_dq_and_output_q(
     return onnx_model
 
 
+def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove Q nodes from the inputs of a quantized ONNX model.
+
+    This supports generating quantized models with low-precision graph I/O.
+
+    Args:
+        onnx_model: ONNX model protobuf to convert
+
+    Returns:
+        ONNX model protobuf with only DQ in the inputs whenever possible.
+
+    Raises:
+        ValueError: If the model is invalid or removal fails
+        RuntimeError: If graph operations fail
+    """
+    logger.info("Deleting Q nodes in the input of a quantized ONNX model.")
+    if not isinstance(onnx_model, onnx.ModelProto):
+        raise ValueError("Input must be an ONNX model protobuf")
+
+    graph = onnx_model.graph
+    if not graph.node:
+        raise ValueError("Model graph is empty")
+
+    initializers, _, tensor_consumers = _get_graph_metadata(graph)
+    q_nodes = [
+        (idx, node) for idx, node in enumerate(graph.node) if node.op_type == "QuantizeLinear"
+    ]
+    q_indices = []
+    graph_input_names = {inp.name: inp for inp in graph.input}
+
+    # Remove Q nodes in the graph inputs
+    for node_idx, node in q_nodes:
+        if not any(inp in graph_input_names for inp in node.input):
+            continue
+
+        inp = node.input[0]
+        for out_name in node.output:
+            logger.debug(f"Processing QDQ node for output {out_name}")
+
+            try:
+                # Each Q should have exactly one DQ consumer
+                dq_node = tensor_consumers[out_name]
+                assert len(dq_node) == 1, f"Expected single consumer for {node.name}"
+                assert dq_node[0].op_type == "DequantizeLinear", (
+                    f"Expected DequantizeLinear consumer for {node.name}"
+                )
+
+                # Rewire the graph so the DQ node reads the graph input directly, bypassing Q
+                dq_node[0].input[0] = inp
+
+                # Set the input precision to match the zero-point precision in the DQ node
+                inp_tensor = graph_input_names[inp]
+                inp_tensor.type.tensor_type.elem_type = initializers[dq_node[0].input[2]].data_type
+
+                # Track QuantizeLinear node indices for cleanup
+                q_indices.append(node_idx)
+
+            except Exception as e:
+                raise RuntimeError(f"Failed to convert node {node.name}: {e!s}")
+
+    # Remove processed nodes
+    for node_idx in sorted(q_indices, reverse=True):
+        del graph.node[node_idx]
+
+    logger.info(f"Removed {len(q_indices)} Q node{'' if len(q_indices) == 1 else 's'}")
+
+    # TODO: remove manual ir_version change once ORT supports ir_version 11
+    onnx_model.ir_version = 10
+
+    return onnx_model
+
+
 def _cast_initializer_to_dtype(
     node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
 ):
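
A minimal sketch of the new function's effect on a toy graph whose input feeds a Q/DQ pair. The model below is hypothetical (names, shapes, and scale/zero-point values are made up); only remove_graph_input_q comes from modelopt, and the expected outcome follows the logic in the diff above: the QuantizeLinear node is deleted, its DequantizeLinear consumer reads the graph input directly, and the input is retyped to the DQ zero-point precision (INT8 here), so callers can feed quantized data as-is.

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    from modelopt.onnx.quantization.qdq_utils import remove_graph_input_q

    # Hypothetical toy model: x (float32) -> Q -> DQ -> Relu -> y
    scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), "scale")
    zp = numpy_helper.from_array(np.array(0, dtype=np.int8), "zp")
    graph = helper.make_graph(
        nodes=[
            helper.make_node("QuantizeLinear", ["x", "scale", "zp"], ["x_q"], name="q"),
            helper.make_node("DequantizeLinear", ["x_q", "scale", "zp"], ["x_dq"], name="dq"),
            helper.make_node("Relu", ["x_dq"], ["y"], name="relu"),
        ],
        name="toy",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 8])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 8])],
        initializer=[scale, zp],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])

    model = remove_graph_input_q(model)

    # The Q node is gone and the graph input now carries the DQ zero-point dtype.
    assert [n.op_type for n in model.graph.node] == ["DequantizeLinear", "Relu"]
    assert model.graph.input[0].type.tensor_type.elem_type == TensorProto.INT8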

modelopt/onnx/quantization/quantize.py

Lines changed: 7 additions & 1 deletion

@@ -60,7 +60,11 @@
 from modelopt.onnx.quantization.int4 import quantize as quantize_int4
 from modelopt.onnx.quantization.int8 import quantize as quantize_int8
 from modelopt.onnx.quantization.ort_utils import update_trt_ep_support
-from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, remove_input_dq_and_output_q
+from modelopt.onnx.quantization.qdq_utils import (
+    qdq_to_dq,
+    remove_graph_input_q,
+    remove_input_dq_and_output_q,
+)
 from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
 from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx
 
@@ -498,6 +502,8 @@ def quantize(
         onnx_model = remove_input_dq_and_output_q(
             onnx_model, quantizable_custom_ops=custom_ops_to_quantize
         )
+        if direct_io_types:
+            onnx_model = remove_graph_input_q(onnx_model)
     # Sort nodes topologically
     graph = gs.import_onnx(onnx_model)
     graph.toposort().cleanup()
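
The same path driven through the Python API, as a hedged sketch: the import location follows quantize.py above and direct_io_types is the parameter shown in this diff, while the remaining argument names are assumptions mirroring the CLI options rather than a confirmed signature.

    from modelopt.onnx.quantization.quantize import quantize

    quantize(
        onnx_path="model.onnx",          # assumed to mirror the CLI --onnx_path option
        quantize_mode="int8",            # one of {"fp8", "int8", "int4"} per get_parser
        direct_io_types=True,            # enables the remove_graph_input_q pass shown above
        output_path="model.quant.onnx",  # assumed to mirror the CLI --output_path option
    )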
