diff --git a/backends/test/suite/context.py b/backends/test/suite/context.py index 16b22b89f87..fd754737060 100644 --- a/backends/test/suite/context.py +++ b/backends/test/suite/context.py @@ -1,6 +1,8 @@ # Test run context management. This is used to determine the test context for reporting # purposes. class TestContext: + subtest_index: int + def __init__( self, test_name: str, test_base_name: str, flow_name: str, params: dict | None ): @@ -8,6 +10,7 @@ def __init__( self.test_base_name = test_base_name self.flow_name = flow_name self.params = params + self.subtest_index = 0 def __enter__(self): global _active_test_context diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index 700baa435fc..06c1c537477 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -119,6 +119,7 @@ def run_model_test( flow, context.test_name, context.test_base_name, + 0, # subtest_index - currently unused for model tests context.params, dynamic_shapes=dynamic_shapes, ) diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py index 8f7fbb1bc03..6ceb9086f71 100644 --- a/backends/test/suite/operators/__init__.py +++ b/backends/test/suite/operators/__init__.py @@ -152,12 +152,16 @@ def _test_op( flow, context.test_name, context.test_base_name, + context.subtest_index, context.params, generate_random_test_inputs=generate_random_test_inputs, ) log_test_summary(run_summary) + # This is reset when a new test is started - it creates the context per-test. + context.subtest_index = context.subtest_index + 1 + if not run_summary.result.is_success(): if run_summary.result.is_backend_failure(): raise RuntimeError("Test failure.") from run_summary.error diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index 6294ab9434f..f4a1f9a653e 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -21,6 +21,7 @@ CSV_FIELD_NAMES = [ "Test ID", "Test Case", + "Subtest", "Flow", "Params", "Result", @@ -163,6 +164,9 @@ class TestCaseSummary: name: str """ The full name of test, including flow and parameter suffixes. """ + subtest_index: int + """ The subtest number. If a test case runs multiple tests, this field can be used to disambiguate. """ + params: dict | None """ Test-specific parameters, such as dtype. """ @@ -356,6 +360,7 @@ def write_csv_row(record: TestCaseSummary, output: TextIO): row = { "Test ID": record.name, "Test Case": record.base_name, + "Subtest": record.subtest_index, "Flow": record.flow, "Params": _serialize_params(record.params), "Result": record.result.to_short_str(), diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index b128d64eca2..4999779b3c9 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -45,6 +45,7 @@ def run_test( # noqa: C901 flow: TestFlow, test_name: str, test_base_name: str, + subtest_index: int, params: dict | None, dynamic_shapes: Any | None = None, generate_random_test_inputs: bool = True, @@ -64,6 +65,7 @@ def build_result( return TestCaseSummary( backend=flow.backend, base_name=test_base_name, + subtest_index=subtest_index, flow=flow.name, name=test_name, params=params, diff --git a/backends/test/suite/tests/test_reporting.py b/backends/test/suite/tests/test_reporting.py index 6ab4817b44c..a6f2ca60bdd 100644 --- a/backends/test/suite/tests/test_reporting.py +++ b/backends/test/suite/tests/test_reporting.py @@ -24,6 +24,7 @@ base_name="test1", flow="flow1", name="test1_backend1_flow1", + subtest_index=0, params=None, result=TestResult.SUCCESS, error=None, @@ -34,6 +35,7 @@ base_name="test1", flow="flow1", name="test1_backend2_flow1", + subtest_index=0, params=None, result=TestResult.LOWER_FAIL, error=None, @@ -44,6 +46,7 @@ base_name="test2", flow="flow1", name="test2_backend1_flow1", + subtest_index=0, params={"dtype": torch.float32}, result=TestResult.SUCCESS_UNDELEGATED, error=None, @@ -54,6 +57,7 @@ base_name="test2", flow="flow1", name="test2_backend2_flow1", + subtest_index=0, params={"use_dynamic_shapes": True}, result=TestResult.SKIPPED, error=None,