Skip to content

Commit 3137c45

Browse files
authored
[Torch FX] Fix "None" Tensor Shape In NNCFGraph Edges (#3747)
### Changes Modify nncf graph builder for torch fx to obtain the tensor shape for edges for some metatypes like avgpool1d, median metatype etc. These are cases where there are multiple output nodes. ### Reason for changes Fix failing tests in executorch openvino backend. ### Related tickets 176898 PTQ Conformance test #756: Pass
1 parent 148c3f1 commit 3137c45

File tree

5 files changed

+91
-5
lines changed

5 files changed

+91
-5
lines changed

src/nncf/experimental/torch/fx/nncf_graph_builder.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,15 @@ def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype)
7575
:param model: Target GraphModule instance.
7676
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
7777
"""
78-
if metatype in [om.PTEmbeddingMetatype]:
79-
weight_node = node.args[0]
78+
PT_METATYPE_TO_FX_METATYPE_MAPPING = {
79+
om.PTEmbeddingMetatype: om.PTAtenEmbeddingMetatype,
80+
om.PTEmbeddingBagMetatype: om.PTAtenEmbeddingBagMetatype,
81+
}
82+
if metatype in PT_METATYPE_TO_FX_METATYPE_MAPPING:
83+
fx_metatype = PT_METATYPE_TO_FX_METATYPE_MAPPING[metatype]
84+
weight_node = node.args[fx_metatype.weight_port_ids[0]]
8085
if weight_node.op == "get_attr":
81-
return om.PTAtenEmbeddingMetatype
86+
return fx_metatype
8287

8388
return metatype
8489

@@ -137,6 +142,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
137142
for source_node in model.graph.nodes:
138143
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
139144
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
145+
140146
is_shared_node = source_node.op in ("get_attr",) and (
141147
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
142148
)
@@ -190,7 +196,16 @@ def get_edge_params(
190196
source_node.meta["val"], (tuple, list)
191197
):
192198
tensor = source_node.meta["val"][0]
193-
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
199+
elif source_nncf_node.metatype in [
200+
om.PTSplitMetatype,
201+
om.PTMaxMetatype,
202+
om.PTMinMetatype,
203+
om.PTMedianMetatype,
204+
om.PTAdaptiveMaxPool1dMetatype,
205+
om.PTAdaptiveMaxPool2dMetatype,
206+
om.PTAdaptiveMaxPool3dMetatype,
207+
om.PTAtenEmbeddingBagMetatype,
208+
] and isinstance(source_node.meta["val"], (tuple, list)):
194209
tensor = source_node.meta["val"][output_idx]
195210
# Assume every outputs corresponds to an unique output_port_id
196211
output_port_id = output_idx

src/nncf/torch/graph/operator_metatypes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,12 @@ class PTMeanMetatype(PTOperatorMetatype):
699699
hw_config_names = [HWConfigOpName.REDUCEMEAN]
700700

701701

702+
@PT_OPERATOR_METATYPES.register()
703+
class PTMedianMetatype(PTOperatorMetatype):
704+
name = "MedianOp"
705+
module_to_function_names = {NamespaceTarget.ATEN: ["median"]}
706+
707+
702708
@PT_OPERATOR_METATYPES.register()
703709
class PTRoundMetatype(PTOperatorMetatype):
704710
name = "RoundOp"
@@ -745,6 +751,16 @@ class PTBatchNormMetatype(PTOperatorMetatype):
745751
bias_port_id = 2
746752

747753

754+
@PT_OPERATOR_METATYPES.register()
755+
class PTAvgPool1dMetatype(PTOperatorMetatype):
756+
name = "AvgPool1DOp"
757+
module_to_function_names = {
758+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["avg_pool1d", "adaptive_avg_pool1d"],
759+
NamespaceTarget.ATEN: ["adaptive_avg_pool1d"],
760+
}
761+
hw_config_names = [HWConfigOpName.AVGPOOL]
762+
763+
748764
@PT_OPERATOR_METATYPES.register()
749765
class PTAvgPool2dMetatype(PTOperatorMetatype):
750766
name = "AvgPool2DOp"
@@ -759,6 +775,7 @@ class PTAvgPool3dMetatype(PTOperatorMetatype):
759775
hw_config_names = [HWConfigOpName.AVGPOOL]
760776

761777

778+
@PT_OPERATOR_METATYPES.register()
762779
class PTAdaptiveMaxPool1dMetatype(PTOperatorMetatype):
763780
name = "AdaptiveMaxPool1DOp"
764781
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["adaptive_max_pool1d"]}
@@ -965,6 +982,14 @@ class PTEmbeddingBagMetatype(PTOperatorMetatype):
965982
weight_port_ids = [1]
966983

967984

985+
@FX_OPERATOR_METATYPES.register()
986+
class PTAtenEmbeddingBagMetatype(OperatorMetatype):
987+
name = "EmbeddingBagOp"
988+
module_to_function_names = {NamespaceTarget.ATEN: ["embedding_bag"]}
989+
hw_config_names = [HWConfigOpName.EMBEDDINGBAG]
990+
weight_port_ids = [0]
991+
992+
968993
@PT_OPERATOR_METATYPES.register()
969994
class PTSoftmaxMetatype(PTOperatorMetatype):
970995
name = "SoftmaxOp"
@@ -1222,6 +1247,7 @@ def get_operator_metatypes() -> list[type[OperatorMetatype]]:
12221247
OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS = [
12231248
PTEmbeddingMetatype,
12241249
PTEmbeddingBagMetatype,
1250+
PTAtenEmbeddingBagMetatype,
12251251
PTModuleEmbeddingBagMetatype,
12261252
PTModuleEmbeddingMetatype,
12271253
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
strict digraph {
2+
"0 embedding_weight" [id=0, type="get_attr"];
3+
"1 embeddingbag_weight" [id=1, type="get_attr"];
4+
"2 x" [id=2, type=input];
5+
"3 embedding" [id=3, type=embedding];
6+
"4 arange" [id=4, type=arange];
7+
"5 reshape" [id=5, type=reshape];
8+
"6 embedding_bag" [id=6, type="embedding_bag"];
9+
"7 getitem" [id=7, type="__getitem__"];
10+
"8 getitem_1" [id=8, type="__getitem__"];
11+
"9 getitem_2" [id=9, type="__getitem__"];
12+
"10 getitem_3" [id=10, type="__getitem__"];
13+
"11 add" [id=11, type=add];
14+
"12 output" [id=12, type=output];
15+
"0 embedding_weight" -> "3 embedding" [style=solid, label="(10, 10)"];
16+
"1 embeddingbag_weight" -> "6 embedding_bag" [style=solid, label="(10, 10)"];
17+
"2 x" -> "3 embedding" [style=solid, label="(1, 1)"];
18+
"2 x" -> "5 reshape" [style=solid, label="(1, 1)"];
19+
"3 embedding" -> "11 add" [style=solid, label="(1, 1, 10)"];
20+
"4 arange" -> "6 embedding_bag" [style=solid, label="(1,)"];
21+
"5 reshape" -> "6 embedding_bag" [style=solid, label="(1,)"];
22+
"6 embedding_bag" -> "7 getitem" [style=solid, label="(1, 10)"];
23+
"6 embedding_bag" -> "8 getitem_1" [style=solid, label="(1,)"];
24+
"6 embedding_bag" -> "9 getitem_2" [style=solid, label="(1,)"];
25+
"6 embedding_bag" -> "10 getitem_3" [style=solid, label="(1,)"];
26+
"7 getitem" -> "11 add" [style=solid, label="(1, 10)"];
27+
"11 add" -> "12 output" [style=solid, label="(1, 1, 10)"];
28+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"embedding_weight": "PTConstNoopMetatype",
3+
"embeddingbag_weight": "PTConstNoopMetatype",
4+
"x": "PTInputNoopMetatype",
5+
"embedding": "PTAtenEmbeddingMetatype",
6+
"arange": "UnknownMetatype",
7+
"reshape": "PTReshapeMetatype",
8+
"embedding_bag": "PTAtenEmbeddingBagMetatype",
9+
"getitem": "PTGatherMetatype",
10+
"getitem_1": "PTGatherMetatype",
11+
"getitem_2": "PTGatherMetatype",
12+
"getitem_3": "PTGatherMetatype",
13+
"add": "PTAddMetatype",
14+
"output": "PTOutputNoopMetatype"
15+
}

tests/torch2/fx/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
4141
from tests.cross_fw.shared.paths import TEST_ROOT
4242
from tests.torch import test_models
43+
from tests.torch.test_models.synthetic import EmbeddingSumModel
4344
from tests.torch.test_models.synthetic import MultiBranchesConnectedModel
4445
from tests.torch.test_models.synthetic import ShortTransformer
4546
from tests.torch.test_models.synthetic import YOLO11N_SDPABlock
@@ -75,6 +76,7 @@ def torchvision_model_case(model_id: str, input_shape: tuple[int,]):
7576
ModelCase(test_models.UNet, "unet", [1, 3, 224, 224]),
7677
ModelCase(partial(ShortTransformer, 5, 10), "synthetic_transformer", [5]),
7778
ModelCase(YOLO11N_SDPABlock, "yolo11n_sdpa_block", YOLO11N_SDPABlock.INPUT_SIZE),
79+
ModelCase(EmbeddingSumModel, "embedding_bag_model", [1, 1]),
7880
)
7981

8082

@@ -121,7 +123,7 @@ def test_model(test_case: ModelCase, regen_ref_data: bool):
121123
model = test_case.model_builder()
122124
model.to(device)
123125

124-
dtype = torch.int32 if test_case.model_id == "synthetic_transformer" else torch.float32
126+
dtype = torch.int32 if test_case.model_id in ["synthetic_transformer", "embedding_bag_model"] else torch.float32
125127
ex_input = torch.ones(test_case.input_shape, dtype=dtype)
126128
exported_model = get_torch_fx_model(model, ex_input)
127129
nncf_graph = GraphConverter.create_nncf_graph(exported_model)

0 commit comments

Comments
 (0)