Skip to content

Commit 382a0aa

Browse files
committed
[Backend Tester] Seed based on test name
ghstack-source-id: f123042 ghstack-comment-id: 3177836622 Pull-Request: #13313
1 parent 459485d commit 382a0aa

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

backends/test/suite/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def run_model_test(
119119
flow,
120120
context.test_name,
121121
context.test_base_name,
122-
0, # subtest_index - currently unused for model tests
122+
0, # subtest_index - currently unused for model tests
123123
context.params,
124124
dynamic_shapes=dynamic_shapes,
125125
)

backends/test/suite/reporting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def is_delegated(self):
207207

208208
@dataclass
209209
class TestSessionState:
210+
seed: int
211+
210212
# True if the CSV header has been written to report__path.
211213
has_written_report_header: bool = False
212214

@@ -291,11 +293,17 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
291293
)
292294

293295

294-
def begin_test_session(report_path: str | None):
296+
def begin_test_session(report_path: str | None, seed: int):
295297
global _active_session
296298

297299
assert _active_session is None, "A test session is already active."
298-
_active_session = TestSessionState(report_path=report_path)
300+
_active_session = TestSessionState(report_path=report_path, seed=seed)
301+
302+
303+
def get_active_test_session() -> TestSessionState | None:
304+
global _active_session
305+
306+
return _active_session
299307

300308

301309
def log_test_summary(summary: TestCaseSummary):

backends/test/suite/runner.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
2+
import hashlib
23
import importlib
4+
import random
35
import re
46
import time
57
import unittest
@@ -26,6 +28,7 @@
2628
begin_test_session,
2729
complete_test_session,
2830
count_ops,
31+
get_active_test_session,
2932
RunSummary,
3033
TestCaseSummary,
3134
TestResult,
@@ -40,6 +43,25 @@
4043
}
4144

4245

46+
def _get_test_seed(test_base_name: str) -> int:
47+
# Set the seed based on the test base name to give consistent inputs between backends. Add the
48+
# run seed to allow for reproducible results, but still allow for run-to-run variation.
49+
# Having a stable hash between runs and across machines is a plus (builtin python hash is not).
50+
# Using MD5 here because it's fast and we don't actually care about cryptographic properties.
51+
test_session = get_active_test_session()
52+
run_seed = (
53+
test_session.seed
54+
if test_session is not None
55+
else random.randint(0, 100_000_000)
56+
)
57+
58+
hasher = hashlib.md5()
59+
data = test_base_name.encode("utf-8")
60+
hasher.update(data)
61+
# Torch doesn't like very long seeds.
62+
return (int.from_bytes(hasher.digest(), "little") % 100_000_000) + run_seed
63+
64+
4365
def run_test( # noqa: C901
4466
model: torch.nn.Module,
4567
inputs: Any,
@@ -59,6 +81,8 @@ def run_test( # noqa: C901
5981
error_statistics: list[ErrorStatistics] = []
6082
extra_stats = {}
6183

84+
torch.manual_seed(_get_test_seed(test_base_name))
85+
6286
# Helper method to construct the summary.
6387
def build_result(
6488
result: TestResult, error: Exception | None = None
@@ -237,6 +261,12 @@ def parse_args():
237261
help="A file to write the test report to, in CSV format.",
238262
default="backend_test_report.csv",
239263
)
264+
parser.add_argument(
265+
"--seed",
266+
nargs="?",
267+
help="The numeric seed value to use for random generation.",
268+
type=int,
269+
)
240270
return parser.parse_args()
241271

242272

@@ -254,7 +284,10 @@ def runner_main():
254284
# lot of log spam. We don't really need the warning here.
255285
warnings.simplefilter("ignore", category=FutureWarning)
256286

257-
begin_test_session(args.report)
287+
seed = args.seed or random.randint(0, 100_000_000)
288+
print(f"Running with seed {seed}.")
289+
290+
begin_test_session(args.report, seed=seed)
258291

259292
if len(args.suite) > 1:
260293
raise NotImplementedError("TODO Support multiple suites.")

0 commit comments

Comments
 (0)