Skip to content

[Backend Tester] Seed based on test name #13313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/GregoryComer/123/head
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backends/test/suite/runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import hashlib
import importlib
import re
import time
Expand Down Expand Up @@ -40,6 +41,16 @@
}


def _get_test_seed(test_base_name: str) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not set a new, global seed every run? And print it somewhere to reproduce. Hardcoding seed ==> we will test with same random numbers every time, not sure if that's what we want.

# Set the seed based on the test base name to give consistent inputs between runs and backends.
# Having a stable hash between runs and across machines is a plus (builtin python hash is not).
# Using MD5 here because it's fast and we don't actually care about cryptographic properties.
hasher = hashlib.md5()
data = test_base_name.encode("utf-8")
hasher.update(data)
# Torch doesn't like very long seeds.
return int.from_bytes(hasher.digest(), "little") % 100_000_000

def run_test( # noqa: C901
model: torch.nn.Module,
inputs: Any,
Expand All @@ -59,6 +70,8 @@ def run_test( # noqa: C901
error_statistics: list[ErrorStatistics] = []
extra_stats = {}

torch.manual_seed(_get_test_seed(test_base_name))

# Helper method to construct the summary.
def build_result(
result: TestResult, error: Exception | None = None
Expand Down
Loading