Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from onnx_graphsurgeon.ir.node import Node
from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
from onnxruntime.quantization.calibrate import CalibrationDataReader
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

from modelopt.onnx.logging_config import logger
from modelopt.onnx.op_types import is_copy_op, is_linear_op
Expand All @@ -36,6 +35,7 @@
find_lowest_common_ancestor,
get_child_nodes,
get_parent_nodes,
infer_shapes,
parse_shapes_spec,
save_onnx,
)
Expand Down Expand Up @@ -966,7 +966,7 @@ def find_nodes_from_matmul_to_exclude(
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")

if calibration_shapes:
nodes_to_exclude = _exclude_matmuls_by_symbolic_inference(
nodes_to_exclude = _exclude_matmuls_by_shape_inference(
model, matmul_nodes, calibration_shapes
)
else:
Expand Down Expand Up @@ -1058,10 +1058,10 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
return unsupported_conv_nodes


def _exclude_matmuls_by_symbolic_inference(
def _exclude_matmuls_by_shape_inference(
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
) -> list[str]:
"""Use symbolic shape inference to find MatMuls with dimension 1."""
"""Use shape inference to find MatMuls with dimension 1."""
# Prepare model for symbolic inference
for graph_input in model.graph.input:
for dim in graph_input.type.tensor_type.shape.dim:
Expand All @@ -1070,11 +1070,13 @@ def _exclude_matmuls_by_symbolic_inference(
dim.dim_value = 1

# Apply calibration shapes if provided
input_shapes = (
parse_shapes_spec(calibration_shapes)
if (calibration_shapes and isinstance(calibration_shapes, str))
else {}
)
input_shapes = {}
if calibration_shapes:
input_shapes = (
parse_shapes_spec(calibration_shapes)
if isinstance(calibration_shapes, str)
else calibration_shapes
)
for graph_input in model.graph.input:
if graph_input.name in input_shapes:
input_shape = input_shapes[graph_input.name]
Expand All @@ -1087,8 +1089,7 @@ def _exclude_matmuls_by_symbolic_inference(
for dim, new_dim_value in zip(tensor_shape, input_shape):
dim.dim_value = new_dim_value

model.graph.ClearField("value_info")
model = SymbolicShapeInference.infer_shapes(model)
model = infer_shapes(model)
value_info_map = {vi.name: vi for vi in model.graph.value_info}

nodes_to_exclude = []
Expand Down
Loading