@@ -71,6 +71,15 @@ def has_const_input(node: Node) -> bool:
71
71
return any (is_const_input (tensor ) for tensor in node .inputs )
72
72
73
73
74
def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
    """Return the shape of every graph input of the ONNX model at ``onnx_path``.

    Args:
        onnx_path: Path to the ONNX model file.

    Returns:
        A dict mapping each ``graph.input`` name to the list of its ``dim_value``
        entries, in dimension order. A dimension that carries no concrete
        ``dim_value`` (for example a symbolic ``dim_param`` dimension) is read
        back as 0, which is what protobuf reports for an unset integer field.
    """
    # Only the graph signature is inspected, so do not pull external weight
    # files into memory while loading.
    onnx_model = onnx.load(onnx_path, load_external_data=False)
    # ``graph_input`` rather than ``input`` to avoid shadowing the builtin.
    return {
        graph_input.name: [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        for graph_input in onnx_model.graph.input
    }
81
+
82
+
74
83
def has_path_type (
75
84
node : Node ,
76
85
graph : Graph ,
@@ -707,7 +716,7 @@ def find_nodes_from_matmul_to_exclude(
707
716
intermediate_generated_files : list [str ] | None = None ,
708
717
calibration_data_reader : CalibrationDataReader = None ,
709
718
calibration_eps : list [str ] = ["cpu" , "cuda:0" , "trt" ],
710
- calibration_shapes : str | None = None ,
719
+ calibration_shapes : str | dict | None = None ,
711
720
) -> list [str ]:
712
721
"""Find MatMul nodes that meets gemv condition to exclude.
713
722
@@ -834,7 +843,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
834
843
835
844
836
845
def _exclude_matmuls_by_symbolic_inference (
837
- model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | None = None
846
+ model : onnx .ModelProto , matmul_nodes : list , calibration_shapes : str | dict | None = None
838
847
) -> list [str ]:
839
848
"""Use symbolic shape inference to find MatMuls with dimension 1."""
840
849
# Prepare model for symbolic inference
@@ -845,7 +854,11 @@ def _exclude_matmuls_by_symbolic_inference(
845
854
dim .dim_value = 1
846
855
847
856
# Apply calibration shapes if provided
848
- input_shapes = parse_shapes_spec (calibration_shapes ) if calibration_shapes else {}
857
+ input_shapes = (
858
+ parse_shapes_spec (calibration_shapes )
859
+ if (calibration_shapes and isinstance (calibration_shapes , str ))
860
+ else {}
861
+ )
849
862
for graph_input in model .graph .input :
850
863
if graph_input .name in input_shapes :
851
864
input_shape = input_shapes [graph_input .name ]
0 commit comments