2727from onnx_graphsurgeon .ir .node import Node
2828from onnx_graphsurgeon .ir .tensor import Constant , Tensor , Variable
2929from onnxruntime .quantization .calibrate import CalibrationDataReader
30- from onnxruntime .tools .symbolic_shape_infer import SymbolicShapeInference
3130
3231from modelopt .onnx .logging_config import logger
3332from modelopt .onnx .op_types import is_copy_op , is_linear_op
@@ -966,7 +965,7 @@ def find_nodes_from_matmul_to_exclude(
966965 logger .debug (f"Found { len (matmul_nodes )} MatMul nodes to analyze" )
967966
968967 if calibration_shapes :
969- nodes_to_exclude = _exclude_matmuls_by_symbolic_inference (
968+ nodes_to_exclude = _exclude_matmuls_by_shape_inference (
970969 model , matmul_nodes , calibration_shapes
971970 )
972971 else :
@@ -1058,10 +1057,10 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
10581057 return unsupported_conv_nodes
10591058
10601059
1061- def _exclude_matmuls_by_symbolic_inference (
1060+ def _exclude_matmuls_by_shape_inference (
10621061 model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | dict | None = None
10631062) -> list [str ]:
1064- """Use symbolic shape inference to find MatMuls with dimension 1."""
1063+ """Use shape inference to find MatMuls with dimension 1."""
10651064 # Prepare model for symbolic inference
10661065 for graph_input in model .graph .input :
10671066 for dim in graph_input .type .tensor_type .shape .dim :
@@ -1070,11 +1069,13 @@ def _exclude_matmuls_by_symbolic_inference(
10701069 dim .dim_value = 1
10711070
10721071 # Apply calibration shapes if provided
1073- input_shapes = (
1074- parse_shapes_spec (calibration_shapes )
1075- if (calibration_shapes and isinstance (calibration_shapes , str ))
1076- else {}
1077- )
1072+ input_shapes = {}
1073+ if calibration_shapes :
1074+ input_shapes = (
1075+ parse_shapes_spec (calibration_shapes )
1076+ if isinstance (calibration_shapes , str )
1077+ else calibration_shapes
1078+ )
10781079 for graph_input in model .graph .input :
10791080 if graph_input .name in input_shapes :
10801081 input_shape = input_shapes [graph_input .name ]
@@ -1087,8 +1088,7 @@ def _exclude_matmuls_by_symbolic_inference(
10871088 for dim , new_dim_value in zip (tensor_shape , input_shape ):
10881089 dim .dim_value = new_dim_value
10891090
1090- model .graph .ClearField ("value_info" )
1091- model = SymbolicShapeInference .infer_shapes (model )
1091+ model = onnx .shape_inference .infer_shapes (model )
10921092 value_info_map = {vi .name : vi for vi in model .graph .value_info }
10931093
10941094 nodes_to_exclude = []
0 commit comments