1
1
import argparse
2
+ import hashlib
2
3
import importlib
4
+ import random
3
5
import re
4
6
import time
5
7
import unittest
26
28
begin_test_session ,
27
29
complete_test_session ,
28
30
count_ops ,
31
+ get_active_test_session ,
29
32
RunSummary ,
30
33
TestCaseSummary ,
31
34
TestResult ,
40
43
}
41
44
42
45
46
+ def _get_test_seed (test_base_name : str ) -> int :
47
+ # Set the seed based on the test base name to give consistent inputs between backends. Add the
48
+ # run seed to allow for reproducible results, but still allow for run-to-run variation.
49
+ # Having a stable hash between runs and across machines is a plus (builtin python hash is not).
50
+ # Using MD5 here because it's fast and we don't actually care about cryptographic properties.
51
+ test_session = get_active_test_session ()
52
+ run_seed = (
53
+ test_session .seed
54
+ if test_session is not None
55
+ else random .randint (0 , 100_000_000 )
56
+ )
57
+
58
+ hasher = hashlib .md5 ()
59
+ data = test_base_name .encode ("utf-8" )
60
+ hasher .update (data )
61
+ # Torch doesn't like very long seeds.
62
+ return (int .from_bytes (hasher .digest (), "little" ) % 100_000_000 ) + run_seed
63
+
64
+
43
65
def run_test ( # noqa: C901
44
66
model : torch .nn .Module ,
45
67
inputs : Any ,
@@ -59,6 +81,8 @@ def run_test( # noqa: C901
59
81
error_statistics : list [ErrorStatistics ] = []
60
82
extra_stats = {}
61
83
84
+ torch .manual_seed (_get_test_seed (test_base_name ))
85
+
62
86
# Helper method to construct the summary.
63
87
def build_result (
64
88
result : TestResult , error : Exception | None = None
@@ -237,6 +261,12 @@ def parse_args():
237
261
help = "A file to write the test report to, in CSV format." ,
238
262
default = "backend_test_report.csv" ,
239
263
)
264
+ parser .add_argument (
265
+ "--seed" ,
266
+ nargs = "?" ,
267
+ help = "The numeric seed value to use for random generation." ,
268
+ type = int ,
269
+ )
240
270
return parser .parse_args ()
241
271
242
272
@@ -254,7 +284,10 @@ def runner_main():
254
284
# lot of log spam. We don't really need the warning here.
255
285
warnings .simplefilter ("ignore" , category = FutureWarning )
256
286
257
- begin_test_session (args .report )
287
+ seed = args .seed or random .randint (0 , 100_000_000 )
288
+ print (f"Running with seed { seed } ." )
289
+
290
+ begin_test_session (args .report , seed = seed )
258
291
259
292
if len (args .suite ) > 1 :
260
293
raise NotImplementedError ("TODO Support multiple suites." )
0 commit comments