Skip to content

[Backend Tester] Clean up operator test logic #12736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f120e70
Update
GregoryComer Jul 18, 2025
0fb85e6
Update
GregoryComer Jul 18, 2025
4d8d844
Update
GregoryComer Jul 19, 2025
dc12b40
Update
GregoryComer Jul 21, 2025
ead0616
Update
GregoryComer Jul 22, 2025
0f13676
Update
GregoryComer Jul 22, 2025
b0b01f2
Update
GregoryComer Jul 22, 2025
8b9c9ef
Update
GregoryComer Jul 22, 2025
06bf03a
Update
GregoryComer Jul 22, 2025
2f8f49b
Update
GregoryComer Jul 22, 2025
8ca7766
Update
GregoryComer Jul 22, 2025
bffb95f
Update
GregoryComer Jul 22, 2025
d21492b
Update
GregoryComer Jul 22, 2025
e2c4ea5
Update
GregoryComer Jul 22, 2025
8230848
Update
GregoryComer Jul 22, 2025
2a1f564
Update
GregoryComer Jul 22, 2025
b35e7b1
Update
GregoryComer Jul 22, 2025
5c4c6ce
Update
GregoryComer Jul 22, 2025
9397803
Update
GregoryComer Jul 22, 2025
9dfeb5a
Update
GregoryComer Jul 22, 2025
ff5c4a5
Update
GregoryComer Jul 22, 2025
42a5de5
Update
GregoryComer Jul 22, 2025
402d8f5
Update
GregoryComer Jul 22, 2025
34d3ab3
Update
GregoryComer Jul 22, 2025
482bd21
Update
GregoryComer Jul 22, 2025
7ef236b
Update
GregoryComer Jul 23, 2025
4a58c9d
Update
GregoryComer Jul 23, 2025
81dfb07
Update
GregoryComer Jul 23, 2025
4d50265
Update
GregoryComer Jul 23, 2025
5f66043
Update
GregoryComer Jul 23, 2025
24e919d
Update
GregoryComer Jul 23, 2025
89757ce
Update
GregoryComer Jul 23, 2025
423f79a
Update
GregoryComer Jul 23, 2025
7a2fab5
Update
GregoryComer Jul 23, 2025
033c231
Update
GregoryComer Jul 23, 2025
27cd171
Update
GregoryComer Jul 23, 2025
7bdd3e5
Update
GregoryComer Jul 23, 2025
e2df06e
Update
GregoryComer Jul 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 1 addition & 138 deletions backends/test/suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,11 @@

import logging
import os
import unittest

from enum import Enum
from typing import Callable

import executorch.backends.test.suite.flow

import torch
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test, runner_main
from executorch.backends.test.suite.runner import runner_main

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
return _ALL_TEST_FLOWS


DTYPES = [
# torch.int8,
# torch.uint8,
# torch.int16,
# torch.uint16,
# torch.int32,
# torch.uint32,
# torch.int64,
# torch.uint64,
# torch.float16,
torch.float32,
# torch.float64,
]

FLOAT_DTYPES = [
torch.float16,
torch.float32,
torch.float64,
]


# The type of test function. This controls the test generation and expected signature.
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
# take an additional dtype parameter.
class TestType(Enum):
STANDARD = 1
DTYPE = 2


# Function annotation for dtype tests. This instructs the test framework to run the test
# for each supported dtype and to pass dtype as a test parameter.
def dtype_test(func):
func.test_type = TestType.DTYPE
return func


# Class annotation for operator tests. This triggers the test framework to register
# the tests.
def operator_test(cls):
_create_tests(cls)
return cls


# Generate test cases for each backend flow.
def _create_tests(cls):
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str):
test_func = getattr(cls, test_name)
for flow in get_test_flows().values():
_create_test_for_backend(cls, test_func, flow)
delattr(cls, test_name)


def _make_wrapped_test(
test_func: Callable,
test_name: str,
flow: TestFlow,
params: dict | None = None,
):
def wrapped_test(self):
with TestContext(test_name, flow.name, params):
test_kwargs = params or {}
test_kwargs["flow"] = flow

test_func(self, **test_kwargs)

wrapped_test._name = test_name
wrapped_test._flow = flow

return wrapped_test


def _create_test_for_backend(
cls,
test_func: Callable,
flow: TestFlow,
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
test_name = f"{test_func.__name__}_{flow.name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
wrapped_test = _make_wrapped_test(
test_func,
test_func.__name__,
flow,
{"dtype": dtype},
)
dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")


def load_tests(loader, suite, pattern):
package_dir = os.path.dirname(__file__)
discovered_suite = loader.discover(
Expand All @@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern):
return suite


class OperatorTest(unittest.TestCase):
def _test_op(self, model, inputs, flow: TestFlow):
context = get_active_test_context()

# This should be set in the wrapped test. See _make_wrapped_test above.
assert context is not None, "Missing test context."

run_summary = run_test(
model,
inputs,
flow,
context.test_name,
context.params,
)

log_test_summary(run_summary)

if not run_summary.result.is_success():
if run_summary.result.is_backend_failure():
raise RuntimeError("Test failure.") from run_summary.error
else:
# Non-backend failure indicates a bad test. Mark as skipped.
raise unittest.SkipTest(
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
)


if __name__ == "__main__":
runner_main()
11 changes: 10 additions & 1 deletion backends/test/suite/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,17 @@ def _filter_tests(
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)

# Handle import / discovery failures - leave them enabled to report nicely at the
# top level. There might be a better way to do this. Internally, unittest seems to
# replace it with a stub method to report the failure.
if "testFailure" in str(test_method):
print(f"Warning: Test {test_case._testMethodName} failed to import.")
return True

if not hasattr(test_method, "_flow"):
print(f"Test missing flow: {test_method}")
raise RuntimeError(
f"Test missing flow: {test_case._testMethodName} {test_method}"
)

flow: TestFlow = test_method._flow

Expand Down
141 changes: 141 additions & 0 deletions backends/test/suite/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
# pyre-unsafe

import os
import unittest

from enum import Enum
from typing import Callable

import torch
from executorch.backends.test.suite import get_test_flows
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test


def load_tests(loader, suite, pattern):
Expand All @@ -16,3 +27,133 @@ def load_tests(loader, suite, pattern):
)
suite.addTests(discovered_suite)
return suite


DTYPES = [
# torch.int8,
# torch.uint8,
# torch.int16,
# torch.uint16,
# torch.int32,
# torch.uint32,
# torch.int64,
# torch.uint64,
# torch.float16,
torch.float32,
# torch.float64,
]

FLOAT_DTYPES = [
torch.float16,
torch.float32,
torch.float64,
]


# The type of test function. This controls the test generation and expected signature.
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
# take an additional dtype parameter.
class TestType(Enum):
STANDARD = 1
DTYPE = 2


# Function annotation for dtype tests. This instructs the test framework to run the test
# for each supported dtype and to pass dtype as a test parameter.
def dtype_test(func):
func.test_type = TestType.DTYPE
return func


# Class annotation for operator tests. This triggers the test framework to register
# the tests.
def operator_test(cls):
_create_tests(cls)
return cls


# Generate test cases for each backend flow.
def _create_tests(cls):
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str):
test_func = getattr(cls, test_name)
for flow in get_test_flows().values():
_create_test_for_backend(cls, test_func, flow)
delattr(cls, test_name)


def _make_wrapped_test(
test_func: Callable,
test_name: str,
flow: TestFlow,
params: dict | None = None,
):
def wrapped_test(self):
with TestContext(test_name, flow.name, params):
test_kwargs = params or {}
test_kwargs["flow"] = flow

test_func(self, **test_kwargs)

wrapped_test._name = test_name
wrapped_test._flow = flow

return wrapped_test


def _create_test_for_backend(
cls,
test_func: Callable,
flow: TestFlow,
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
test_name = f"{test_func.__name__}_{flow.name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
wrapped_test = _make_wrapped_test(
test_func,
test_func.__name__,
flow,
{"dtype": dtype},
)
dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")


class OperatorTest(unittest.TestCase):
def _test_op(self, model, inputs, flow: TestFlow):
context = get_active_test_context()

# This should be set in the wrapped test. See _make_wrapped_test above.
assert context is not None, "Missing test context."

run_summary = run_test(
model,
inputs,
flow,
context.test_name,
context.params,
)

log_test_summary(run_summary)

if not run_summary.result.is_success():
if run_summary.result.is_backend_failure():
raise RuntimeError("Test failure.") from run_summary.error
else:
# Non-backend failure indicates a bad test. Mark as skipped.
raise unittest.SkipTest(
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
)
8 changes: 6 additions & 2 deletions backends/test/suite/operators/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@


import torch

from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
from executorch.backends.test.suite.flow import TestFlow

from executorch.backends.test.suite.operators import (
dtype_test,
operator_test,
OperatorTest,
)


class Model(torch.nn.Module):
def forward(self, x, y):
Expand Down
8 changes: 6 additions & 2 deletions backends/test/suite/operators/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from typing import Optional

import torch

from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
from executorch.backends.test.suite.flow import TestFlow

from executorch.backends.test.suite.operators import (
dtype_test,
operator_test,
OperatorTest,
)


class Model(torch.nn.Module):
def forward(self, x, y):
Expand Down
8 changes: 6 additions & 2 deletions backends/test/suite/operators/test_elu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@


import torch

from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
from executorch.backends.test.suite.flow import TestFlow

from executorch.backends.test.suite.operators import (
dtype_test,
operator_test,
OperatorTest,
)


class Model(torch.nn.Module):
def __init__(self, alpha=1.0, inplace=False):
Expand Down
Loading
Loading