Skip to content

Commit 0ea2fdd

Browse files
committed
Fix for 'SymbolicShapeInference' error
Signed-off-by: gcunhase <[email protected]>
1 parent e6e0d2c commit 0ea2fdd

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from onnx_graphsurgeon.ir.node import Node
2828
from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
30-
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
3130

3231
from modelopt.onnx.logging_config import logger
3332
from modelopt.onnx.op_types import is_copy_op, is_linear_op
@@ -966,7 +965,7 @@ def find_nodes_from_matmul_to_exclude(
966965
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")
967966

968967
if calibration_shapes:
969-
nodes_to_exclude = _exclude_matmuls_by_symbolic_inference(
968+
nodes_to_exclude = _exclude_matmuls_by_shape_inference(
970969
model, matmul_nodes, calibration_shapes
971970
)
972971
else:
@@ -1058,10 +1057,10 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
10581057
return unsupported_conv_nodes
10591058

10601059

1061-
def _exclude_matmuls_by_symbolic_inference(
1060+
def _exclude_matmuls_by_shape_inference(
10621061
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
10631062
) -> list[str]:
1064-
"""Use symbolic shape inference to find MatMuls with dimension 1."""
1063+
"""Use shape inference to find MatMuls with dimension 1."""
10651064
# Prepare model for symbolic inference
10661065
for graph_input in model.graph.input:
10671066
for dim in graph_input.type.tensor_type.shape.dim:
@@ -1070,11 +1069,13 @@ def _exclude_matmuls_by_symbolic_inference(
10701069
dim.dim_value = 1
10711070

10721071
# Apply calibration shapes if provided
1073-
input_shapes = (
1074-
parse_shapes_spec(calibration_shapes)
1075-
if (calibration_shapes and isinstance(calibration_shapes, str))
1076-
else {}
1077-
)
1072+
input_shapes = {}
1073+
if calibration_shapes:
1074+
input_shapes = (
1075+
parse_shapes_spec(calibration_shapes)
1076+
if isinstance(calibration_shapes, str)
1077+
else calibration_shapes
1078+
)
10781079
for graph_input in model.graph.input:
10791080
if graph_input.name in input_shapes:
10801081
input_shape = input_shapes[graph_input.name]
@@ -1087,8 +1088,7 @@ def _exclude_matmuls_by_symbolic_inference(
10871088
for dim, new_dim_value in zip(tensor_shape, input_shape):
10881089
dim.dim_value = new_dim_value
10891090

1090-
model.graph.ClearField("value_info")
1091-
model = SymbolicShapeInference.infer_shapes(model)
1091+
model = onnx.shape_inference.infer_shapes(model)
10921092
value_info_map = {vi.name: vi for vi in model.graph.value_info}
10931093

10941094
nodes_to_exclude = []

0 commit comments

Comments
 (0)