2727from onnx_graphsurgeon .ir .node import Node
2828from onnx_graphsurgeon .ir .tensor import Constant , Tensor , Variable
2929from onnxruntime .quantization .calibrate import CalibrationDataReader
30- from onnxruntime .tools .symbolic_shape_infer import SymbolicShapeInference
3130
3231from modelopt .onnx .logging_config import logger
3332from modelopt .onnx .op_types import is_copy_op , is_linear_op
3635 find_lowest_common_ancestor ,
3736 get_child_nodes ,
3837 get_parent_nodes ,
38+ infer_shapes ,
3939 parse_shapes_spec ,
4040 save_onnx ,
4141)
@@ -966,7 +966,7 @@ def find_nodes_from_matmul_to_exclude(
966966 logger .debug (f"Found { len (matmul_nodes )} MatMul nodes to analyze" )
967967
968968 if calibration_shapes :
969- nodes_to_exclude = _exclude_matmuls_by_symbolic_inference (
969+ nodes_to_exclude = _exclude_matmuls_by_shape_inference (
970970 model , matmul_nodes , calibration_shapes
971971 )
972972 else :
@@ -1058,10 +1058,10 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
10581058 return unsupported_conv_nodes
10591059
10601060
1061- def _exclude_matmuls_by_symbolic_inference (
1061+ def _exclude_matmuls_by_shape_inference (
10621062 model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | dict | None = None
10631063) -> list [str ]:
1064- """Use symbolic shape inference to find MatMuls with dimension 1."""
1064+ """Use shape inference to find MatMuls with dimension 1."""
10651065 # Prepare model for symbolic inference
10661066 for graph_input in model .graph .input :
10671067 for dim in graph_input .type .tensor_type .shape .dim :
@@ -1070,11 +1070,13 @@ def _exclude_matmuls_by_symbolic_inference(
10701070 dim .dim_value = 1
10711071
10721072 # Apply calibration shapes if provided
1073- input_shapes = (
1074- parse_shapes_spec (calibration_shapes )
1075- if (calibration_shapes and isinstance (calibration_shapes , str ))
1076- else {}
1077- )
1073+ input_shapes = {}
1074+ if calibration_shapes :
1075+ input_shapes = (
1076+ parse_shapes_spec (calibration_shapes )
1077+ if isinstance (calibration_shapes , str )
1078+ else calibration_shapes
1079+ )
10781080 for graph_input in model .graph .input :
10791081 if graph_input .name in input_shapes :
10801082 input_shape = input_shapes [graph_input .name ]
@@ -1087,9 +1089,9 @@ def _exclude_matmuls_by_symbolic_inference(
10871089 for dim , new_dim_value in zip (tensor_shape , input_shape ):
10881090 dim .dim_value = new_dim_value
10891091
1090- model .graph .ClearField ("value_info" )
1091- model = SymbolicShapeInference .infer_shapes (model )
1092+ model = infer_shapes (model )
10921093 value_info_map = {vi .name : vi for vi in model .graph .value_info }
1094+ value_info_map .update ({vi .name : vi for vi in model .graph .output })
10931095
10941096 nodes_to_exclude = []
10951097 for matmul_node in matmul_nodes :
0 commit comments