diff --git a/src/nncf/onnx/graph/nncf_graph_builder.py b/src/nncf/onnx/graph/nncf_graph_builder.py index 50271475bef..7079ddcb772 100644 --- a/src/nncf/onnx/graph/nncf_graph_builder.py +++ b/src/nncf/onnx/graph/nncf_graph_builder.py @@ -12,6 +12,7 @@ from typing import Any, Optional import onnx +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference import nncf from nncf.common.graph import NNCFGraph @@ -347,7 +348,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: :return: NNCFGraph. """ onnx_model = GraphConverter._replace_empty_node_name(onnx_model) - onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + onnx_model = SymbolicShapeInference.infer_shapes(onnx_model) edge_info_mapping = get_edge_info_mapping(onnx_model) children_node_mapping = get_children_node_mapping(onnx_model) parents_node_mapping = get_parents_node_mapping(onnx_model) diff --git a/src/nncf/onnx/graph/passes.py b/src/nncf/onnx/graph/passes.py index f723e6cfc06..a3d7abe95f5 100644 --- a/src/nncf/onnx/graph/passes.py +++ b/src/nncf/onnx/graph/passes.py @@ -11,6 +11,7 @@ import onnx from onnx.reference.ops import load_op +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from nncf.onnx.graph.onnx_helper import get_children from nncf.onnx.graph.onnx_helper import get_children_node_mapping @@ -74,8 +75,8 @@ def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto: :param model: The ONNX model to be preprocessed. :return: A preprocessed ONNX model, ready for quantization. """ - preprocessed_model = onnx.shape_inference.infer_shapes(model) - # The `eliminate_nop_cast` pass should be applied after onnx.shape_inference.infer_shapes() call. + preprocessed_model = SymbolicShapeInference.infer_shapes(model) + # The `eliminate_nop_cast` pass should be applied after `SymbolicShapeInference.infer_shapes()` call. # Otherwise, not all no-op Cast nodes will be found. preprocessed_model = eliminate_nop_cast(preprocessed_model) return preprocessed_model diff --git a/tests/onnx/data/reference_graphs/original_nncf_graph/synthetic/unified_embedding_model.dot b/tests/onnx/data/reference_graphs/original_nncf_graph/synthetic/unified_embedding_model.dot index d347156730d..6a45e8b6bd3 100644 --- a/tests/onnx/data/reference_graphs/original_nncf_graph/synthetic/unified_embedding_model.dot +++ b/tests/onnx/data/reference_graphs/original_nncf_graph/synthetic/unified_embedding_model.dot @@ -10,8 +10,9 @@ strict digraph { "0 Cast" -> "1 Embedding" [label="[1, 3]", style=dashed]; "1 Embedding" -> "4 Concat" [label="[1, 3, 5]", style=solid]; "2 MatMul_1" -> "3 Reshape" [label="[3, 1, 5]", style=solid]; -"4 Concat" -> "5 MatMul_2" [label="[]", style=solid]; -"5 MatMul_2" -> "7 nncf_model_output_0" [label="[1, 6]", style=solid]; +"3 Reshape" -> "4 Concat" [label="[1, 3, 5]", style=solid]; +"4 Concat" -> "5 MatMul_2" [label="[1, 6, 5]", style=solid]; +"5 MatMul_2" -> "7 nncf_model_output_0" [label="[1, 6, 1]", style=solid]; "6 nncf_model_input_0" -> "0 Cast" [label="[1, 3]", style=solid]; "6 nncf_model_input_0" -> "2 MatMul_1" [label="[1, 3]", style=solid]; } diff --git a/tests/onnx/models.py b/tests/onnx/models.py index 4ab7f6e7c4c..dffa2622d0d 100644 --- a/tests/onnx/models.py +++ b/tests/onnx/models.py @@ -1709,8 +1709,8 @@ def __init__(self): reshape_tensor_name = "R" reshape_tensor = create_initializer_tensor( name=reshape_tensor_name, - tensor_array=np.array([1, 3, 5]).astype(np.float32), - data_type=onnx.TensorProto.FLOAT, + tensor_array=np.array([1, 3, 5], dtype=np.int64), + data_type=onnx.TensorProto.INT64, ) reshape_output_name = "Reshape_Y" reshape_node = onnx.helper.make_node( @@ -1726,13 +1726,13 @@ def __init__(self): op_type="Concat", inputs=[embedding_output_name, reshape_output_name], outputs=[concat_output_name], - axis=0, + axis=1, ) matmul_2_tensor_name = "W_2" matmul_2_tensor = create_initializer_tensor( name=matmul_2_tensor_name, - tensor_array=rng.uniform(0, 1, (1, 5)).astype(np.float32), + tensor_array=rng.uniform(0, 1, (1, 5)).astype(np.float32).T, data_type=onnx.TensorProto.FLOAT, ) matmul_2_node = onnx.helper.make_node(