[Backend Tester] Migrate to pytest #14456
Merged
Changes shown from 8 of 30 commits.
```diff
@@ -11,6 +11,7 @@
 import os

+import executorch.backends.test.suite.flow
 import torch

 from executorch.backends.test.suite.flow import TestFlow
 from executorch.backends.test.suite.runner import runner_main

@@ -55,6 +56,11 @@ def get_test_flows() -> dict[str, TestFlow]:
     return _ALL_TEST_FLOWS


+def dtype_to_str(dtype: torch.dtype) -> str:
+    # Strip off "torch."
+    return str(dtype)[6:]
+
+
 def load_tests(loader, suite, pattern):
     package_dir = os.path.dirname(__file__)
     discovered_suite = loader.discover(
```

Review comment on `dtype_to_str`: This utility function is used for generating the display name for parameterized tests.
New file (147 lines):
```python
from typing import Any

import pytest
import torch

from executorch.backends.test.suite.flow import all_flows
from executorch.backends.test.suite.reporting import _sum_op_counts
from executorch.backends.test.suite.runner import run_test


def pytest_configure(config):
    backends = set()

    for flow in all_flows().values():
        config.addinivalue_line(
            "markers",
            f"flow_{flow.name}: mark a test as testing the {flow.name} flow",
        )

        if flow.backend not in backends:
            config.addinivalue_line(
                "markers",
                f"backend_{flow.backend}: mark a test as testing the {flow.backend} backend",
            )
            backends.add(flow.backend)


class TestRunner:
    def __init__(self, flow, test_name, test_base_name):
        self._flow = flow
        self._test_name = test_name
        self._test_base_name = test_base_name
        self._subtest = 0
        self._results = []

    def lower_and_run_model(
        self,
        model: torch.nn.Module,
        inputs: Any,
        generate_random_test_inputs=True,
        dynamic_shapes=None,
    ):
        run_summary = run_test(
            model,
            inputs,
            self._flow,
            self._test_name,
            self._test_base_name,
            self._subtest,
            None,
            generate_random_test_inputs=generate_random_test_inputs,
            dynamic_shapes=dynamic_shapes,
        )

        self._subtest += 1
        self._results.append(run_summary)

        if not run_summary.result.is_success():
            if run_summary.result.is_backend_failure():
                raise RuntimeError("Test failure.") from run_summary.error
            else:
                # Non-backend failure indicates a bad test. Mark as skipped.
                pytest.skip(
                    f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
                )


@pytest.fixture(
    params=[
        pytest.param(
            f,
            marks=[
                getattr(pytest.mark, f"flow_{f.name}"),
                getattr(pytest.mark, f"backend_{f.backend}"),
            ],
        )
        for f in all_flows().values()
    ],
    ids=str,
)
def test_runner(request):
    return TestRunner(request.param, request.node.name, request.node.originalname)


@pytest.hookimpl(optionalhook=True)
def pytest_json_runtest_metadata(item, call):
    metadata = {"subtests": []}

    if hasattr(item, "funcargs") and "test_runner" in item.funcargs:
        runner_instance = item.funcargs["test_runner"]

        for record in runner_instance._results:
            subtest_metadata = {}

            error_message = ""
            if record.error is not None:
                error_str = str(record.error)
                if len(error_str) > 400:
                    error_message = error_str[:200] + "..." + error_str[-200:]
                else:
                    error_message = error_str

            subtest_metadata["Test ID"] = record.name
            subtest_metadata["Test Case"] = record.base_name
            subtest_metadata["Subtest"] = record.subtest_index
            subtest_metadata["Flow"] = record.flow
            subtest_metadata["Result"] = record.result.to_short_str()
            subtest_metadata["Result Detail"] = record.result.to_detail_str()
            subtest_metadata["Error"] = error_message
            subtest_metadata["Delegated"] = "True" if record.is_delegated() else "False"
            subtest_metadata["Quantize Time (s)"] = (
                f"{record.quantize_time.total_seconds():.3f}"
                if record.quantize_time
                else None
            )
            subtest_metadata["Lower Time (s)"] = (
                f"{record.lower_time.total_seconds():.3f}"
                if record.lower_time
                else None
            )

            for output_idx, error_stats in enumerate(record.tensor_error_statistics):
                subtest_metadata[f"Output {output_idx} Error Max"] = (
                    f"{error_stats.error_max:.3f}"
                )
                subtest_metadata[f"Output {output_idx} Error MAE"] = (
                    f"{error_stats.error_mae:.3f}"
                )
                subtest_metadata[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"

            subtest_metadata["Delegated Nodes"] = _sum_op_counts(
                record.delegated_op_counts
            )
            subtest_metadata["Undelegated Nodes"] = _sum_op_counts(
                record.undelegated_op_counts
            )
            if record.delegated_op_counts:
                subtest_metadata["Delegated Ops"] = dict(record.delegated_op_counts)
            if record.undelegated_op_counts:
                subtest_metadata["Undelegated Ops"] = dict(record.undelegated_op_counts)
            subtest_metadata["PTE Size (Kb)"] = (
                f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
            )

            metadata["subtests"].append(subtest_metadata)

    return metadata
```

Review comment on `pytest_json_runtest_metadata` (truncated in the page capture): "For …"
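The error handling in `pytest_json_runtest_metadata` caps long error messages by keeping only the first and last 200 characters. The same logic pulled out into a standalone function for illustration (the name `truncate_error` is mine, not from the PR):

```python
def truncate_error(error_str: str) -> str:
    # Matches the hook above: messages over 400 characters are
    # reduced to the first 200 + "..." + the last 200 (403 total).
    if len(error_str) > 400:
        return error_str[:200] + "..." + error_str[-200:]
    return error_str

print(truncate_error("short message"))     # short message
print(len(truncate_error("x" * 1000)))     # 403
```

Keeping both ends of the message preserves the exception type (at the start) and the most proximate cause (at the end), which is usually what a CI report needs.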
Review comment: FYI, we don't exactly love pytest when not working in OSS.
Reply: Do you have concerns with using pytest? Or suggested changes?