5 changes: 3 additions & 2 deletions modelopt/onnx/quantization/__main__.py
@@ -36,7 +36,7 @@ def get_parser() -> argparse.ArgumentParser:
        type=str,
        choices=["fp8", "int8", "int4"],
        default="int8",
        help=("Quantization mode for the given ONNX model."),
        help="Quantization mode for the given ONNX model.",
    )
    argparser.add_argument(
        "--calibration_method",
@@ -246,7 +246,8 @@ def get_parser() -> argparse.ArgumentParser:
action="store_true",
help=(
"If True, the I/O types in the quantized ONNX model will be modified to be lower precision whenever "
"possible. Else, they will match the I/O types in the given ONNX model."
"possible. Else, they will match the I/O types in the given ONNX model. "
"The currently supported precisions are {fp16, int8, fp8}."
),
)
return argparser
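For context, here is a minimal sketch of exercising the updated parser. The option name for the `store_true` flag above is not visible in this hunk; `--direct_io_types` is inferred from the `direct_io_types` parameter used in quantize.py below, and `--onnx_path` is an assumed flag name, so treat both as unconfirmed.

```python
# A hedged sketch; flag names other than the quantize-mode choices shown
# above are inferences from this PR, not confirmed by the diff itself.
from modelopt.onnx.quantization.__main__ import get_parser

args = get_parser().parse_args(
    ["--onnx_path", "model.onnx", "--quantize_mode", "fp8", "--direct_io_types"]
)
assert args.quantize_mode in {"fp8", "int8", "int4"}
assert args.direct_io_types  # the store_true flag from the hunk above
```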
72 changes: 72 additions & 0 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -871,6 +871,78 @@ def remove_input_dq_and_output_q(
    return onnx_model


def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove Q nodes from the inputs of a quantized ONNX model.

    This supports generating quantized models with low-precision graph I/O.

    Args:
        onnx_model: ONNX model protobuf to convert

    Returns:
        ONNX model protobuf with only DQ nodes at the inputs whenever possible.

    Raises:
        ValueError: If the model is invalid or removal fails
        RuntimeError: If graph operations fail
    """
    logger.info("Deleting Q nodes in the input of a quantized ONNX model.")
    if not isinstance(onnx_model, onnx.ModelProto):
        raise ValueError("Input must be an ONNX model protobuf")

    graph = onnx_model.graph
    if not graph.node:
        raise ValueError("Model graph is empty")

    initializers, _, tensor_consumers = _get_graph_metadata(graph)
    q_nodes = [
        (idx, node) for idx, node in enumerate(graph.node) if node.op_type == "QuantizeLinear"
    ]
    q_indices = []
    graph_input_names = {inp.name: inp for inp in graph.input}

    # Remove Q nodes in the graph inputs
    for node_idx, node in q_nodes:
        if not any(inp in graph_input_names for inp in node.input):
            continue

        inp = node.input[0]
        for out_name in node.output:
            logger.debug(f"Processing QDQ node for output {out_name}")

            try:
                # Each Q node output should have exactly one DQ consumer
                dq_node = tensor_consumers[out_name]
                assert len(dq_node) == 1, f"Expected single consumer for {node.name}"
                assert dq_node[0].op_type == "DequantizeLinear", (
                    f"Expected DequantizeLinear consumer for {node.name}"
                )

                # Rewire the graph to feed the graph input directly into the DQ node
                dq_node[0].input[0] = inp

                # Set the input precision to match the zero-point precision in the DQ node
                inp_tensor = graph_input_names[inp]
                inp_tensor.type.tensor_type.elem_type = initializers[dq_node[0].input[2]].data_type

                # Track QuantizeLinear node indices for cleanup
                q_indices.append(node_idx)

            except Exception as e:
                raise RuntimeError(f"Failed to convert node {node.name}: {e!s}") from e

    # Remove processed nodes
    for node_idx in sorted(q_indices, reverse=True):
        del graph.node[node_idx]

    logger.info(f"Removed {len(q_indices)} Q node{'' if len(q_indices) == 1 else 's'}")

    # TODO: remove manual ir_version change once ORT supports ir_version 11
    onnx_model.ir_version = 10

    return onnx_model


def _cast_initializer_to_dtype(
    node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
):
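To make the rewiring concrete, here is a minimal, hedged sketch of `remove_graph_input_q` applied to a toy QDQ graph; the node and tensor names are illustrative, and the model is built only for this example.

```python
# Toy model whose fp32 input feeds a Q/DQ pair; names are illustrative.
import onnx
from onnx import TensorProto, helper

from modelopt.onnx.quantization.qdq_utils import remove_graph_input_q

scale = helper.make_tensor("x_scale", TensorProto.FLOAT, [], [0.1])
zp = helper.make_tensor("x_zp", TensorProto.INT8, [], [0])
graph = helper.make_graph(
    nodes=[
        helper.make_node("QuantizeLinear", ["x", "x_scale", "x_zp"], ["x_q"]),
        helper.make_node("DequantizeLinear", ["x_q", "x_scale", "x_zp"], ["x_dq"]),
        helper.make_node("Relu", ["x_dq"], ["y"]),
    ],
    name="toy",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
    initializer=[scale, zp],
)
model = helper.make_model(graph)

# Before: x (fp32) -> QuantizeLinear -> DequantizeLinear -> Relu
# After:  x (int8) ------------------> DequantizeLinear -> Relu
model = remove_graph_input_q(model)
assert model.graph.input[0].type.tensor_type.elem_type == TensorProto.INT8
assert all(n.op_type != "QuantizeLinear" for n in model.graph.node)
```

The input's element type is taken from the DQ node's zero-point initializer (`dq_node[0].input[2]`), which is why the graph input above ends up as int8.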
8 changes: 7 additions & 1 deletion modelopt/onnx/quantization/quantize.py
@@ -60,7 +60,11 @@
from modelopt.onnx.quantization.int4 import quantize as quantize_int4
from modelopt.onnx.quantization.int8 import quantize as quantize_int8
from modelopt.onnx.quantization.ort_utils import update_trt_ep_support
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, remove_input_dq_and_output_q
from modelopt.onnx.quantization.qdq_utils import (
    qdq_to_dq,
    remove_graph_input_q,
    remove_input_dq_and_output_q,
)
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx

@@ -498,6 +502,8 @@ def quantize(
    onnx_model = remove_input_dq_and_output_q(
        onnx_model, quantizable_custom_ops=custom_ops_to_quantize
    )
    if direct_io_types:
        onnx_model = remove_graph_input_q(onnx_model)
    # Sort nodes topologically
    graph = gs.import_onnx(onnx_model)
    graph.toposort().cleanup()
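Finally, a hedged end-to-end sketch of the Python entry point, assuming `quantize()` exposes `direct_io_types` as a keyword argument (as the hunk above suggests); the other keyword names are assumptions about the signature, not confirmed by this diff.

```python
# A usage sketch, not the definitive API: keyword names other than
# direct_io_types (added in this PR) are assumptions about quantize().
from modelopt.onnx.quantization.quantize import quantize

quantize(
    onnx_path="model.onnx",          # fp32 source model (illustrative path)
    quantize_mode="int8",            # one of {fp8, int8, int4} per the CLI help
    direct_io_types=True,            # keep graph inputs in low precision
    output_path="model.quant.onnx",  # assumed output keyword
)
```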