diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py
index a27a3b339..1be2e75df 100644
--- a/modelopt/onnx/quantization/__main__.py
+++ b/modelopt/onnx/quantization/__main__.py
@@ -36,7 +36,7 @@ def get_parser() -> argparse.ArgumentParser:
         type=str,
         choices=["fp8", "int8", "int4"],
         default="int8",
-        help=("Quantization mode for the given ONNX model."),
+        help="Quantization mode for the given ONNX model.",
     )
     argparser.add_argument(
         "--calibration_method",
@@ -246,7 +246,8 @@ def get_parser() -> argparse.ArgumentParser:
         action="store_true",
         help=(
             "If True, the I/O types in the quantized ONNX model will be modified to be lower precision whenever "
-            "possible. Else, they will match the I/O types in the given ONNX model."
+            "possible. Else, they will match the I/O types in the given ONNX model. "
+            "The currently supported precisions are {fp16, int8, fp8}."
         ),
     )
     return argparser
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 38ed010c0..4b8fe867c 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -871,6 +871,78 @@ def remove_input_dq_and_output_q(
     return onnx_model
 
 
+def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove Q nodes from the inputs of a quantized ONNX model.
+
+    This supports generating quantized models with low-precision graph I/O.
+
+    Args:
+        onnx_model: ONNX model protobuf to convert.
+
+    Returns:
+        ONNX model protobuf with only DQ nodes at the inputs whenever possible.
+
+    Raises:
+        ValueError: If the model is invalid or removal fails.
+        RuntimeError: If graph operations fail.
+    """
+    logger.info("Deleting Q nodes at the inputs of the quantized ONNX model.")
+    if not isinstance(onnx_model, onnx.ModelProto):
+        raise ValueError("Input must be an ONNX model protobuf")
+
+    graph = onnx_model.graph
+    if not graph.node:
+        raise ValueError("Model graph is empty")
+
+    initializers, _, tensor_consumers = _get_graph_metadata(graph)
+    q_nodes = [
+        (idx, node) for idx, node in enumerate(graph.node) if node.op_type == "QuantizeLinear"
+    ]
+    q_indices = []
+    graph_input_names = {inp.name: inp for inp in graph.input}
+
+    # Remove Q nodes in the graph inputs
+    for node_idx, node in q_nodes:
+        if not any(inp in graph_input_names for inp in node.input):
+            continue
+
+        inp = node.input[0]
+        for out_name in node.output:
+            logger.debug(f"Processing QDQ node for output {out_name}")
+
+            try:
+                # Each Q node should have exactly one DQ consumer
+                dq_node = tensor_consumers[out_name]
+                assert len(dq_node) == 1, f"Expected single consumer for {node.name}"
+                assert dq_node[0].op_type == "DequantizeLinear", (
+                    f"Expected DequantizeLinear consumer for {node.name}"
+                )
+
+                # Rewire the graph to connect the graph input directly to the DQ node, bypassing the Q node
+                dq_node[0].input[0] = inp
+
+                # Set the input precision to match the zero-point precision in the DQ node
+                inp_tensor = graph_input_names[inp]
+                inp_tensor.type.tensor_type.elem_type = initializers[dq_node[0].input[2]].data_type
+
+                # Track QuantizeLinear node indices for cleanup
+                q_indices.append(node_idx)
+
+            except Exception as e:
+                raise RuntimeError(f"Failed to convert node {node.name}: {e!s}")
+
+    # Remove processed nodes
+    for node_idx in sorted(q_indices, reverse=True):
+        del graph.node[node_idx]
+
+    logger.info(f"Removed {len(q_indices)} Q node{'' if len(q_indices) == 1 else 's'}")
+
+    # TODO: remove manual ir_version change once ORT supports ir_version 11
+    onnx_model.ir_version = 10
+
+    return onnx_model
+
+
 def _cast_initializer_to_dtype(
     node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
 ):
diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py
index 8207631c9..9f6f9cae9 100755
--- a/modelopt/onnx/quantization/quantize.py
+++ b/modelopt/onnx/quantization/quantize.py
@@ -60,7 +60,11 @@ from modelopt.onnx.quantization.int4 import quantize as quantize_int4
 from modelopt.onnx.quantization.int8 import quantize as quantize_int8
 from modelopt.onnx.quantization.ort_utils import update_trt_ep_support
-from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, remove_input_dq_and_output_q
+from modelopt.onnx.quantization.qdq_utils import (
+    qdq_to_dq,
+    remove_graph_input_q,
+    remove_input_dq_and_output_q,
+)
 from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
 from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx
@@ -498,6 +502,8 @@ def quantize(
             onnx_model = remove_input_dq_and_output_q(
                 onnx_model, quantizable_custom_ops=custom_ops_to_quantize
            )
+            if direct_io_types:
+                onnx_model = remove_graph_input_q(onnx_model)
     # Sort nodes topologically
     graph = gs.import_onnx(onnx_model)
     graph.toposort().cleanup()
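A minimal usage sketch of the new helper, assuming a model that was already quantized with Q/DQ nodes at its graph inputs; in the normal flow, quantize() calls this automatically when the direct I/O types option is set. The file paths below are hypothetical.

```python
# Sketch only: strip QuantizeLinear nodes at the graph inputs of an already-quantized
# model so that only DequantizeLinear remains at the boundary and callers can feed
# low-precision tensors directly. File paths are hypothetical.
import onnx

from modelopt.onnx.quantization.qdq_utils import remove_graph_input_q

model = onnx.load("model.quant.onnx")  # hypothetical QDQ-quantized model
model = remove_graph_input_q(model)

# Graph inputs now carry the zero-point dtype of their DQ consumers (e.g. INT8 or FP8).
for inp in model.graph.input:
    print(inp.name, onnx.TensorProto.DataType.Name(inp.type.tensor_type.elem_type))

onnx.save(model, "model.quant.direct_io.onnx")
```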