11import argparse
22import hashlib
33import importlib
4+ import random
45import re
56import time
67import unittest
2728 begin_test_session ,
2829 complete_test_session ,
2930 count_ops ,
31+ get_active_test_session ,
3032 RunSummary ,
3133 TestCaseSummary ,
3234 TestResult ,
4244
4345
4446def _get_test_seed (test_base_name : str ) -> int :
45- # Set the seed based on the test base name to give consistent inputs between runs and backends.
46- # Having a stable hash between runs and across machines is a plus (builtin python hash is not).
47+ # Set the seed based on the test base name to give consistent inputs between backends. Add the
48+ # run seed to allow for reproducible results, but still allow for run-to-run variation.
49+ # Having a stable hash between runs and across machines is a plus (builtin python hash is not).
4750 # Using MD5 here because it's fast and we don't actually care about cryptographic properties.
51+ test_session = get_active_test_session ()
52+ run_seed = (
53+ test_session .seed
54+ if test_session is not None
55+ else random .randint (0 , 100_000_000 )
56+ )
57+
4858 hasher = hashlib .md5 ()
4959 data = test_base_name .encode ("utf-8" )
5060 hasher .update (data )
5161 # Torch doesn't like very long seeds.
52- return int .from_bytes (hasher .digest (), "little" ) % 100_000_000
62+ return (int .from_bytes (hasher .digest (), "little" ) % 100_000_000 ) + run_seed
63+
5364
5465def run_test ( # noqa: C901
5566 model : torch .nn .Module ,
@@ -250,6 +261,12 @@ def parse_args():
250261 help = "A file to write the test report to, in CSV format." ,
251262 default = "backend_test_report.csv" ,
252263 )
264+ parser .add_argument (
265+ "--seed" ,
266+ nargs = "?" ,
267+ help = "The numeric seed value to use for random generation." ,
268+ type = int ,
269+ )
253270 return parser .parse_args ()
254271
255272
@@ -267,7 +284,10 @@ def runner_main():
267284 # lot of log spam. We don't really need the warning here.
268285 warnings .simplefilter ("ignore" , category = FutureWarning )
269286
270- begin_test_session (args .report )
287+ seed = args .seed or random .randint (0 , 100_000_000 )
288+ print (f"Running with seed { seed } ." )
289+
290+ begin_test_session (args .report , seed = seed )
271291
272292 if len (args .suite ) > 1 :
273293 raise NotImplementedError ("TODO Support multiple suites." )
0 commit comments