diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index 06c1c537477..76b2d2966f6 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -119,7 +119,7 @@ def run_model_test( flow, context.test_name, context.test_base_name, - 0, # subtest_index - currently unused for model tests + 0, # subtest_index - currently unused for model tests context.params, dynamic_shapes=dynamic_shapes, ) diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index f4a1f9a653e..ce8a48dcc12 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -207,6 +207,8 @@ def is_delegated(self): @dataclass class TestSessionState: + seed: int + # True if the CSV header has been written to report__path. has_written_report_header: bool = False @@ -291,11 +293,17 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter: ) -def begin_test_session(report_path: str | None): +def begin_test_session(report_path: str | None, seed: int): global _active_session assert _active_session is None, "A test session is already active." - _active_session = TestSessionState(report_path=report_path) + _active_session = TestSessionState(report_path=report_path, seed=seed) + + +def get_active_test_session() -> TestSessionState | None: + global _active_session + + return _active_session def log_test_summary(summary: TestCaseSummary): diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index eea1ce6b404..6caf27afe92 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -1,5 +1,7 @@ import argparse +import hashlib import importlib +import random import re import time import unittest @@ -26,6 +28,7 @@ begin_test_session, complete_test_session, count_ops, + get_active_test_session, RunSummary, TestCaseSummary, TestResult, @@ -40,6 +43,25 @@ } +def _get_test_seed(test_base_name: str) -> int: + # Set the seed based on the test base name to give consistent inputs between backends. Add the + # run seed to allow for reproducible results, but still allow for run-to-run variation. + # Having a stable hash between runs and across machines is a plus (builtin python hash is not). + # Using MD5 here because it's fast and we don't actually care about cryptographic properties. + test_session = get_active_test_session() + run_seed = ( + test_session.seed + if test_session is not None + else random.randint(0, 100_000_000) + ) + + hasher = hashlib.md5() + data = test_base_name.encode("utf-8") + hasher.update(data) + # Torch doesn't like very long seeds. + return (int.from_bytes(hasher.digest(), "little") % 100_000_000) + run_seed + + def run_test( # noqa: C901 model: torch.nn.Module, inputs: Any, @@ -59,6 +81,8 @@ def run_test( # noqa: C901 error_statistics: list[ErrorStatistics] = [] extra_stats = {} + torch.manual_seed(_get_test_seed(test_base_name)) + # Helper method to construct the summary. def build_result( result: TestResult, error: Exception | None = None @@ -237,6 +261,12 @@ def parse_args(): help="A file to write the test report to, in CSV format.", default="backend_test_report.csv", ) + parser.add_argument( + "--seed", + nargs="?", + help="The numeric seed value to use for random generation.", + type=int, + ) return parser.parse_args() @@ -254,7 +284,10 @@ def runner_main(): # lot of log spam. We don't really need the warning here. warnings.simplefilter("ignore", category=FutureWarning) - begin_test_session(args.report) + seed = args.seed or random.randint(0, 100_000_000) + print(f"Running with seed {seed}.") + + begin_test_session(args.report, seed=seed) if len(args.suite) > 1: raise NotImplementedError("TODO Support multiple suites.")