Skip to content

Commit 543c2ca

Browse files
committed
Replaced onnx.shape_inferend.infer_shapes with utils function
Signed-off-by: gcunhase <[email protected]>
1 parent 57b24fa commit 543c2ca

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
find_lowest_common_ancestor,
3636
get_child_nodes,
3737
get_parent_nodes,
38+
infer_shapes,
3839
parse_shapes_spec,
3940
save_onnx,
4041
)
@@ -1088,7 +1089,7 @@ def _exclude_matmuls_by_shape_inference(
10881089
for dim, new_dim_value in zip(tensor_shape, input_shape):
10891090
dim.dim_value = new_dim_value
10901091

1091-
model = onnx.shape_inference.infer_shapes(model)
1092+
model = infer_shapes(model)
10921093
value_info_map = {vi.name: vi for vi in model.graph.value_info}
10931094

10941095
nodes_to_exclude = []

0 commit comments

Comments
 (0)