Commit 4476f21

[5597849] Fix for 'SymbolicShapeInference' error (#453)
Signed-off-by: gcunhase <[email protected]>
1 parent bffe2ff commit 4476f21

File tree

2 files changed: +14 -12 lines

modelopt/onnx/quantization/graph_utils.py

Lines changed: 13 additions & 11 deletions

@@ -27,7 +27,6 @@
 from onnx_graphsurgeon.ir.node import Node
 from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
 from onnxruntime.quantization.calibrate import CalibrationDataReader
-from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
 
 from modelopt.onnx.logging_config import logger
 from modelopt.onnx.op_types import is_copy_op, is_linear_op
@@ -36,6 +35,7 @@
     find_lowest_common_ancestor,
     get_child_nodes,
     get_parent_nodes,
+    infer_shapes,
     parse_shapes_spec,
     save_onnx,
 )
@@ -966,7 +966,7 @@ def find_nodes_from_matmul_to_exclude(
     logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")
 
     if calibration_shapes:
-        nodes_to_exclude = _exclude_matmuls_by_symbolic_inference(
+        nodes_to_exclude = _exclude_matmuls_by_shape_inference(
             model, matmul_nodes, calibration_shapes
         )
     else:
@@ -1058,10 +1058,10 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
     return unsupported_conv_nodes
 
 
-def _exclude_matmuls_by_symbolic_inference(
+def _exclude_matmuls_by_shape_inference(
     model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
 ) -> list[str]:
-    """Use symbolic shape inference to find MatMuls with dimension 1."""
+    """Use shape inference to find MatMuls with dimension 1."""
     # Prepare model for symbolic inference
     for graph_input in model.graph.input:
         for dim in graph_input.type.tensor_type.shape.dim:
@@ -1070,11 +1070,13 @@ def _exclude_matmuls_by_symbolic_inference(
                 dim.dim_value = 1
 
     # Apply calibration shapes if provided
-    input_shapes = (
-        parse_shapes_spec(calibration_shapes)
-        if (calibration_shapes and isinstance(calibration_shapes, str))
-        else {}
-    )
+    input_shapes = {}
+    if calibration_shapes:
+        input_shapes = (
+            parse_shapes_spec(calibration_shapes)
+            if isinstance(calibration_shapes, str)
+            else calibration_shapes
+        )
     for graph_input in model.graph.input:
         if graph_input.name in input_shapes:
             input_shape = input_shapes[graph_input.name]
@@ -1087,9 +1089,9 @@ def _exclude_matmuls_by_symbolic_inference(
             for dim, new_dim_value in zip(tensor_shape, input_shape):
                 dim.dim_value = new_dim_value
 
-    model.graph.ClearField("value_info")
-    model = SymbolicShapeInference.infer_shapes(model)
+    model = infer_shapes(model)
     value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map.update({vi.name: vi for vi in model.graph.output})
 
     nodes_to_exclude = []
     for matmul_node in matmul_nodes:
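
For readers following the change, the sketch below approximates the updated _exclude_matmuls_by_shape_inference flow end to end. It is illustrative only: it swaps modelopt's infer_shapes helper for onnx.shape_inference.infer_shapes, uses a toy stand-in for parse_shapes_spec (assuming a "name:1x3x224x224" spec format), and guesses the dimension-1 criterion on MatMul outputs, since that part lies outside this hunk.

import onnx


def _parse_spec(spec: str) -> dict[str, list[int]]:
    # Hypothetical stand-in for modelopt.onnx.utils.parse_shapes_spec;
    # assumes a "name:1x3x224x224,other:2x16" style spec string.
    shapes = {}
    for item in spec.split(","):
        name, dims = item.split(":")
        shapes[name] = [int(d) for d in dims.split("x")]
    return shapes


def matmuls_with_dim_one(
    model: onnx.ModelProto, calibration_shapes: str | dict | None = None
) -> list[str]:
    """Return names of MatMul nodes whose inferred output shape contains a 1."""
    # 1. Pin dynamic/symbolic input dims to 1 so shape inference can run.
    for graph_input in model.graph.input:
        for dim in graph_input.type.tensor_type.shape.dim:
            if dim.dim_param or dim.dim_value <= 0:
                dim.dim_value = 1

    # 2. Accept calibration shapes as a spec string or a ready-made dict
    #    (the dict case is what this commit stops discarding).
    input_shapes: dict = {}
    if calibration_shapes:
        input_shapes = (
            _parse_spec(calibration_shapes)
            if isinstance(calibration_shapes, str)
            else calibration_shapes
        )
    for graph_input in model.graph.input:
        if graph_input.name in input_shapes:
            dims = graph_input.type.tensor_type.shape.dim
            for dim, new_dim_value in zip(dims, input_shapes[graph_input.name]):
                dim.dim_value = new_dim_value

    # 3. Run shape inference and index the results; graph outputs are merged in
    #    so MatMuls that feed a model output still get a shape entry.
    model = onnx.shape_inference.infer_shapes(model)
    value_info_map = {vi.name: vi for vi in model.graph.value_info}
    value_info_map.update({vi.name: vi for vi in model.graph.output})

    # 4. Guessed criterion: flag MatMuls whose output has any dimension of 1.
    excluded = []
    for node in model.graph.node:
        if node.op_type != "MatMul":
            continue
        vi = value_info_map.get(node.output[0])
        if vi is not None and any(
            d.dim_value == 1 for d in vi.type.tensor_type.shape.dim
        ):
            excluded.append(node.name)
    return excluded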

modelopt/onnx/quantization/quantize.py

Lines changed: 1 addition & 1 deletion

@@ -470,7 +470,7 @@ def quantize(
         calibration_eps,
     )
 
-    if not calibration_shapes:
+    if calibrate_per_node and not calibration_shapes:
         calibration_shapes = get_input_shapes(onnx_path)
 
     if quantize_mode in ["fp8", "int8"]:
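
As a usage note, the quantize.py change only touches the fallback path: default input shapes are now derived from the model only when per-node calibration is requested. A minimal sketch of that gating, with get_input_shapes_from_model as a hypothetical stand-in for the real get_input_shapes(onnx_path) helper:

import onnx


def get_input_shapes_from_model(model: onnx.ModelProto) -> dict[str, list[int]]:
    # Hypothetical helper: read static input shapes straight from the graph.
    return {
        graph_input.name: [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
        for graph_input in model.graph.input
    }


def resolve_calibration_shapes(model, calibration_shapes, calibrate_per_node: bool):
    # Before this commit the fallback ran whenever calibration_shapes was empty;
    # now it only runs when per-node calibration actually needs the shapes.
    if calibrate_per_node and not calibration_shapes:
        calibration_shapes = get_input_shapes_from_model(model)
    return calibration_shapes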
