4040from uuid import uuid4 as uuid
4141
4242import numpy as np
43+ import torch
4344
4445import ray
4546from ray import tune
@@ -805,7 +806,11 @@ def parse_arguments():
805806 help = "Perturbation interval for PopulationBasedTraining." ,
806807 )
807808 tune_parser .add_argument (
808- "--seed" , type = int , metavar = "<int>" , default = 42 , help = "Random seed."
809+ "--seed" ,
810+ type = int ,
811+ metavar = "<int>" ,
812+ default = 42 ,
813+ help = "Random seed. (0 means no seed.)" ,
809814 )
810815
811816 # Workload
@@ -870,10 +875,29 @@ def set_algorithm(experiment_name, config):
870875 """
871876 Configure search algorithm.
872877 """
878+ # Pre-set seed if user sets seed to 0
879+ if args .seed == 0 :
880+ print (
881+ "Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
882+ )
883+ if input ().lower () != "y" :
884+ sys .exit (0 )
885+ args .seed = None
886+ else :
887+ torch .manual_seed (args .seed )
888+ np .random .seed (args .seed )
889+ random .seed (args .seed )
890+
873891 if args .algorithm == "hyperopt" :
874- algorithm = HyperOptSearch (points_to_evaluate = best_params )
892+ algorithm = HyperOptSearch (
893+ points_to_evaluate = best_params ,
894+ random_state_seed = args .seed ,
895+ )
875896 elif args .algorithm == "ax" :
876- ax_client = AxClient (enforce_sequential_optimization = False )
897+ ax_client = AxClient (
898+ enforce_sequential_optimization = False ,
899+ random_seed = args .seed ,
900+ )
877901 AxClientMetric = namedtuple ("AxClientMetric" , "minimize" )
878902 ax_client .create_experiment (
879903 name = experiment_name ,
@@ -884,16 +908,23 @@ def set_algorithm(experiment_name, config):
884908 elif args .algorithm == "optuna" :
885909 algorithm = OptunaSearch (points_to_evaluate = best_params , seed = args .seed )
886910 elif args .algorithm == "pbt" :
911+ print ("Warning: PBT does not support seed values. args.seed will be ignored." )
887912 algorithm = PopulationBasedTraining (
888913 time_attr = "training_iteration" ,
889914 perturbation_interval = args .perturbation ,
890915 hyperparam_mutations = config ,
891916 synch = True ,
892917 )
893918 elif args .algorithm == "random" :
894- algorithm = BasicVariantGenerator (max_concurrent = args .jobs )
919+ algorithm = BasicVariantGenerator (
920+ max_concurrent = args .jobs ,
921+ random_state = args .seed ,
922+ )
923+
924+ # A wrapper algorithm for limiting the number of concurrent trials.
895925 if args .algorithm not in ["random" , "pbt" ]:
896926 algorithm = ConcurrencyLimiter (algorithm , max_concurrent = args .jobs )
927+
897928 return algorithm
898929
899930
0 commit comments