@@ -72,6 +72,15 @@ def has_const_input(node: Node) -> bool:
72
72
return any (is_const_input (tensor ) for tensor in node .inputs )
73
73
74
74
75
+ def get_input_shapes (onnx_path : str ) -> dict [str , list [int ]]:
76
+ """Returns the input shapes of the given ONNX model."""
77
+ onnx_model = onnx .load (onnx_path )
78
+ input_shape_dict = {}
79
+ for input in onnx_model .graph .input :
80
+ input_shape_dict [input .name ] = [x .dim_value for x in input .type .tensor_type .shape .dim ]
81
+ return input_shape_dict
82
+
83
+
75
84
def has_path_type (
76
85
node : Node ,
77
86
graph : Graph ,
@@ -923,7 +932,7 @@ def find_nodes_from_matmul_to_exclude(
923
932
intermediate_generated_files : list [str ] | None = None ,
924
933
calibration_data_reader : CalibrationDataReader = None ,
925
934
calibration_eps : list [str ] = ["cpu" , "cuda:0" , "trt" ],
926
- calibration_shapes : str | None = None ,
935
+ calibration_shapes : str | dict | None = None ,
927
936
) -> list [str ]:
928
937
"""Find MatMul nodes that meets gemv condition to exclude.
929
938
@@ -1050,7 +1059,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
1050
1059
1051
1060
1052
1061
def _exclude_matmuls_by_symbolic_inference (
1053
- model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | None = None
1062
+ model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | dict | None = None
1054
1063
) -> list [str ]:
1055
1064
"""Use symbolic shape inference to find MatMuls with dimension 1."""
1056
1065
# Prepare model for symbolic inference
@@ -1061,7 +1070,11 @@ def _exclude_matmuls_by_symbolic_inference(
1061
1070
dim .dim_value = 1
1062
1071
1063
1072
# Apply calibration shapes if provided
1064
- input_shapes = parse_shapes_spec (calibration_shapes ) if calibration_shapes else {}
1073
+ input_shapes = (
1074
+ parse_shapes_spec (calibration_shapes )
1075
+ if (calibration_shapes and isinstance (calibration_shapes , str ))
1076
+ else {}
1077
+ )
1065
1078
for graph_input in model .graph .input :
1066
1079
if graph_input .name in input_shapes :
1067
1080
input_shape = input_shapes [graph_input .name ]
0 commit comments