Skip to content

Commit 6a16b6f

Browse files
committed
[Backend Tester] Skip tests with undelegated conv3d ops
ghstack-source-id: 5391bbd ghstack-comment-id: 3276803008 Pull-Request: #14185
1 parent def410a commit 6a16b6f

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

backends/test/suite/flows/vulkan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _create_vulkan_flow_base(
2020
tester_factory=VulkanTester,
2121
quantize=quantize_stage_factory is not None,
2222
quantize_stage_factory=quantize_stage_factory,
23-
skip_patterns=["float16", "float64"], # Not supported in swiftshader
23+
skip_patterns=["float16", "float64"], # Not supported in swiftshader
2424
)
2525

2626

backends/test/suite/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def wrapped_test(self):
5353
}
5454
with TestContext(test_name, test_func.__name__, flow.name, params):
5555
if flow.should_skip_test(test_name):
56-
raise unittest.SkipTest(f"Skipping test due to matching flow {flow.name} skip patterns")
56+
raise unittest.SkipTest(
57+
f"Skipping test due to matching flow {flow.name} skip patterns"
58+
)
5759

5860
test_func(self, flow, dtype, use_dynamic_shapes)
5961

backends/test/suite/operators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def _make_wrapped_test(
9898
def wrapped_test(self):
9999
with TestContext(test_name, test_base_name, flow.name, params):
100100
if flow.should_skip_test(test_name):
101-
raise unittest.SkipTest(f"Skipping test due to matching flow {flow.name} skip patterns")
101+
raise unittest.SkipTest(
102+
f"Skipping test due to matching flow {flow.name} skip patterns"
103+
)
102104

103105
test_kwargs = copy.copy(params) or {}
104106
test_kwargs["flow"] = flow

backends/test/suite/runner.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
TestResult,
3636
)
3737
from executorch.exir import EdgeProgramManager
38+
from executorch.exir.dialects._ops import ops as exir_ops
3839

3940

4041
# A list of all runnable test suites and the corresponding python package.
@@ -44,6 +45,24 @@
4445
}
4546

4647

48+
def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool:
49+
# Returns true if the model contains patterns that will fail when running on the ET
50+
# portable kernel library.
51+
52+
# Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at
53+
# the input meta to determine the rank.
54+
for node in program.graph.nodes:
55+
if (
56+
node.op == "call_function"
57+
and node.target == exir_ops.edge.aten.convolution.default
58+
):
59+
in_rank = node.args[0].meta["val"].dim()
60+
if in_rank != 4:
61+
return True
62+
63+
return False
64+
65+
4766
def _get_test_seed(test_base_name: str) -> int:
4867
# Set the seed based on the test base name to give consistent inputs between backends. Add the
4968
# run seed to allow for reproducible results, but still allow for run-to-run variation.
@@ -163,7 +182,7 @@ def build_result(
163182
# Check if any undelegated ops are in the unsupported ops set.
164183
has_unsupported_ops = any(
165184
op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys()
166-
)
185+
) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program)
167186

168187
# Skip the test if there are unsupported portable ops remaining.
169188
if has_unsupported_ops:

0 commit comments

Comments
 (0)