1212from contextlib import contextmanager
1313import time
1414import json
15+ import random
1516import numpy as np
1617import platform
1718from graph_net .torch .backend .graph_compiler_backend import GraphCompilerBackend
3334}
3435
3536
def set_seed(random_seed):
    """Seed all RNGs used by the pipeline (Python, NumPy, PyTorch) for reproducibility.

    Args:
        random_seed: Integer seed applied to every random number generator.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every CUDA device, including the current one,
        # so the separate torch.cuda.manual_seed call the original made was redundant.
        torch.cuda.manual_seed_all(random_seed)
44+
45+
3646def load_class_from_file (
3747 args : argparse .Namespace , class_name : str , device : str
3848) -> Type [torch .nn .Module ]:
@@ -226,6 +236,7 @@ def test_single_model(args):
226236 flush = True ,
227237 )
228238
239+ runtime_seed = 1024
229240 eager_failure = False
230241 expected_out = None
231242 eager_types = []
@@ -239,6 +250,8 @@ def test_single_model(args):
239250 file = sys .stderr ,
240251 flush = True ,
241252 )
253+
254+ torch .manual_seed (runtime_seed )
242255 expected_out = eager_model_call ()
243256 if not isinstance (expected_out , tuple ):
244257 expected_out = (expected_out ,)
@@ -270,6 +283,7 @@ def test_single_model(args):
270283 else :
271284 compiled_model = compiler (model )
272285
286+ torch .manual_seed (runtime_seed )
273287 compiled_model_call = lambda : compiled_model (** input_dict )
274288 compiled_stats = measure_performance (compiled_model_call , args , compiler )
275289 print (
@@ -480,6 +494,9 @@ def is_single_model_dir(model_dir):
480494
481495def main (args ):
482496 assert os .path .isdir (args .model_path )
497+
498+ initalize_seed = 123
499+ set_seed (random_seed = initalize_seed )
483500 if is_single_model_dir (args .model_path ):
484501 test_single_model (args )
485502 else :
0 commit comments