Skip to content

Commit e7b7975

Browse files
committed
Update
[ghstack-poisoned]
1 parent 7ba2a7f commit e7b7975

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

backends/test/suite/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def run_model_test(
119119
flow,
120120
context.test_name,
121121
context.test_base_name,
122-
0, # subtest_index - currently unused for model tests
122+
0, # subtest_index - currently unused for model tests
123123
context.params,
124124
dynamic_shapes=dynamic_shapes,
125125
)

backends/test/suite/reporting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def is_delegated(self):
207207

208208
@dataclass
209209
class TestSessionState:
210+
seed: int
211+
210212
# True if the CSV header has been written to report__path.
211213
has_written_report_header: bool = False
212214

@@ -291,11 +293,17 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
291293
)
292294

293295

294-
def begin_test_session(report_path: str | None):
296+
def begin_test_session(report_path: str | None, seed: int):
295297
global _active_session
296298

297299
assert _active_session is None, "A test session is already active."
298-
_active_session = TestSessionState(report_path=report_path)
300+
_active_session = TestSessionState(report_path=report_path, seed=seed)
301+
302+
303+
def get_active_test_session() -> TestSessionState | None:
304+
global _active_session
305+
306+
return _active_session
299307

300308

301309
def log_test_summary(summary: TestCaseSummary):

backends/test/suite/runner.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import hashlib
33
import importlib
4+
import random
45
import re
56
import time
67
import unittest
@@ -27,6 +28,7 @@
2728
begin_test_session,
2829
complete_test_session,
2930
count_ops,
31+
get_active_test_session,
3032
RunSummary,
3133
TestCaseSummary,
3234
TestResult,
@@ -42,14 +44,23 @@
4244

4345

4446
def _get_test_seed(test_base_name: str) -> int:
45-
# Set the seed based on the test base name to give consistent inputs between runs and backends.
46-
# Having a stable hash between runs and across machines is a plus (builtin python hash is not).
47+
# Set the seed based on the test base name to give consistent inputs between backends. Add the
48+
# run seed to allow for reproducible results, but still allow for run-to-run variation.
49+
# Having a stable hash between runs and across machines is a plus (builtin python hash is not).
4750
# Using MD5 here because it's fast and we don't actually care about cryptographic properties.
51+
test_session = get_active_test_session()
52+
run_seed = (
53+
test_session.seed
54+
if test_session is not None
55+
else random.randint(0, 100_000_000)
56+
)
57+
4858
hasher = hashlib.md5()
4959
data = test_base_name.encode("utf-8")
5060
hasher.update(data)
5161
# Torch doesn't like very long seeds.
52-
return int.from_bytes(hasher.digest(), "little") % 100_000_000
62+
return (int.from_bytes(hasher.digest(), "little") % 100_000_000) + run_seed
63+
5364

5465
def run_test( # noqa: C901
5566
model: torch.nn.Module,
@@ -250,6 +261,12 @@ def parse_args():
250261
help="A file to write the test report to, in CSV format.",
251262
default="backend_test_report.csv",
252263
)
264+
parser.add_argument(
265+
"--seed",
266+
nargs="?",
267+
help="The numeric seed value to use for random generation.",
268+
type=int,
269+
)
253270
return parser.parse_args()
254271

255272

@@ -267,7 +284,10 @@ def runner_main():
267284
# lot of log spam. We don't really need the warning here.
268285
warnings.simplefilter("ignore", category=FutureWarning)
269286

270-
begin_test_session(args.report)
287+
seed = args.seed or random.randint(0, 100_000_000)
288+
print(f"Running with seed {seed}.")
289+
290+
begin_test_session(args.report, seed=seed)
271291

272292
if len(args.suite) > 1:
273293
raise NotImplementedError("TODO Support multiple suites.")

0 commit comments

Comments
 (0)