Skip to content

Commit 074b942

Browse files
committed
[Backend Tester] Skip tests with undelegated conv3d ops
ghstack-source-id: 03b0448 ghstack-comment-id: 3276803008 Pull-Request: #14185
1 parent bafc692 commit 074b942

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

backends/test/suite/runner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
UNSUPPORTED_PORTABLE_OPS = {
1616
"aten::_embedding_bag",
1717
"aten::_adaptive_avg_pool2d",
18+
"aten::adaptive_max_pool2d",
1819
"aten::median",
1920
"aten::median.dim",
2021
"aten::round.decimals",
@@ -34,6 +35,7 @@
3435
TestResult,
3536
)
3637
from executorch.exir import EdgeProgramManager
38+
from executorch.exir.dialects._ops import ops as exir_ops
3739

3840

3941
# A list of all runnable test suites and the corresponding python package.
@@ -43,6 +45,24 @@
4345
}
4446

4547

48+
def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool:
49+
# Returns true if the model contains patterns that will fail when running on the ET
50+
# portable kernel library.
51+
52+
# Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at
53+
# the input meta to determine the rank.
54+
for node in program.graph.nodes:
55+
if (
56+
node.op == "call_function"
57+
and node.target == exir_ops.edge.aten.convolution.default
58+
):
59+
in_rank = node.args[0].meta["val"].dim()
60+
if in_rank != 4:
61+
return True
62+
63+
return False
64+
65+
4666
def _get_test_seed(test_base_name: str) -> int:
4767
# Set the seed based on the test base name to give consistent inputs between backends. Add the
4868
# run seed to allow for reproducible results, but still allow for run-to-run variation.
@@ -162,7 +182,7 @@ def build_result(
162182
# Check if any undelegated ops are in the unsupported ops set.
163183
has_unsupported_ops = any(
164184
op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys()
165-
)
185+
) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program)
166186

167187
# Skip the test if there are unsupported portable ops remaining.
168188
if has_unsupported_ops:

0 commit comments

Comments
 (0)