Skip to content

[Backend Tester] Add test name filter #12625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions backends/test/suite/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import os
import unittest

from dataclasses import dataclass
from types import ModuleType
from typing import Pattern

from executorch.backends.test.suite.flow import TestFlow

Expand All @@ -18,8 +20,19 @@
#


@dataclass
class TestFilter:
"""A set of filters for test discovery."""

backends: set[str] | None
""" The set of backends to include. If None, all backends are included. """

name_regex: Pattern[str] | None
""" A regular expression to filter test names. If None, all tests are included. """


def discover_tests(
root_module: ModuleType, backends: set[str] | None
root_module: ModuleType, test_filter: TestFilter
) -> unittest.TestSuite:
# Collect all tests using the unittest discovery mechanism then filter down.

Expand All @@ -32,32 +45,37 @@ def discover_tests(
module_dir = os.path.dirname(module_file)
suite = loader.discover(module_dir)

return _filter_tests(suite, backends)
return _filter_tests(suite, test_filter)


def _filter_tests(
suite: unittest.TestSuite, backends: set[str] | None
suite: unittest.TestSuite, test_filter: TestFilter
) -> unittest.TestSuite:
# Recursively traverse the test suite and add them to the filtered set.
filtered_suite = unittest.TestSuite()

for child in suite:
if isinstance(child, unittest.TestSuite):
filtered_suite.addTest(_filter_tests(child, backends))
filtered_suite.addTest(_filter_tests(child, test_filter))
elif isinstance(child, unittest.TestCase):
if _is_test_enabled(child, backends):
if _is_test_enabled(child, test_filter):
filtered_suite.addTest(child)
else:
raise RuntimeError(f"Unexpected test type: {type(child)}")

return filtered_suite


def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool:
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)
flow: TestFlow = test_method._flow

if test_filter.backends is not None and flow.backend not in test_filter.backends:
return False

if test_filter.name_regex is not None and not test_filter.name_regex.search(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like we are reimplementing features from unittest or pytest :p

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I've tried to avoid this, but unfortunately, it doesn't look like the unittest package structure is very extensible in the way I need. There aren't a lot of hooks to control reporting / filtering / discovery without writing custom driver code. I'm open to suggestions, but this seemed to be the lowest friction path with unittest. Switching to pytest might be an option, but I'm hoping that I don't need to do much more non-differentiated work like this.

test_case.id()
):
return False

if backends is not None:
flow: TestFlow = test_method._flow
return flow.backend in backends
else:
return True
return True
24 changes: 13 additions & 11 deletions backends/test/suite/runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
import importlib
import re
import unittest

from typing import Callable

import torch

from executorch.backends.test.harness import Tester
from executorch.backends.test.suite.discovery import discover_tests
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
from executorch.backends.test.suite.reporting import (
begin_test_session,
complete_test_session,
Expand Down Expand Up @@ -148,18 +149,17 @@ def parse_args():
parser.add_argument(
"-b", "--backend", nargs="*", help="The backend or backends to test."
)
parser.add_argument(
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
)
return parser.parse_args()


def test(suite):
if isinstance(suite, unittest.TestSuite):
print(f"Suite: {suite}")
for t in suite:
test(t)
else:
print(f"Leaf: {type(suite)} {suite}")
print(f" {suite.__name__}")
print(f" {callable(suite)}")
def build_test_filter(args: argparse.Namespace) -> TestFilter:
return TestFilter(
backends=set(args.backend) if args.backend is not None else None,
name_regex=re.compile(args.filter) if args.filter is not None else None,
)


def runner_main():
Expand All @@ -172,7 +172,9 @@ def runner_main():

test_path = NAMED_SUITES[args.suite[0]]
test_root = importlib.import_module(test_path)
suite = discover_tests(test_root, args.backend)
test_filter = build_test_filter(args)

suite = discover_tests(test_root, test_filter)
unittest.TextTestRunner(verbosity=2).run(suite)

summary = complete_test_session()
Expand Down
Loading