@@ -216,11 +216,13 @@ def evaluate(self, metrics):
216216 error = "ERR" in metrics .values () or "ERR" in reference .values ()
217217 not_found = "N/A" in metrics .values () or "N/A" in reference .values ()
218218 if error or not_found :
219- return ERROR_METRIC
219+ return ( ERROR_METRIC , "-" , "-" )
220220 ppa = self .get_ppa (metrics )
221221 gamma = ppa / 10
222222 score = ppa * (self .step_ / 100 ) ** (- 1 ) + (gamma * metrics ["num_drc" ])
223- return score
223+ effective_clk_period = metrics ["clk_period" ] - metrics ["worst_slack" ]
224+ num_drc = metrics ["num_drc" ]
225+ return (score , effective_clk_period , num_drc )
224226
225227
226228def parse_arguments ():
@@ -464,32 +466,33 @@ def parse_arguments():
464466 return args
465467
466468
467- def set_algorithm (experiment_name , config ):
469+ def set_algorithm (algorithm_name , experiment_name , best_params , seed , perturbation ,
470+ jobs , config ):
468471 """
469472 Configure search algorithm.
470473 """
471474 # Pre-set seed if user sets seed to 0
472- if args . seed == 0 :
475+ if seed == 0 :
473476 print (
474477 "Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
475478 )
476479 if input ().lower () != "y" :
477480 sys .exit (0 )
478- args . seed = None
481+ seed = None
479482 else :
480- torch .manual_seed (args . seed )
481- np .random .seed (args . seed )
482- random .seed (args . seed )
483+ torch .manual_seed (seed )
484+ np .random .seed (seed )
485+ random .seed (seed )
483486
484- if args . algorithm == "hyperopt" :
487+ if algorithm_name == "hyperopt" :
485488 algorithm = HyperOptSearch (
486489 points_to_evaluate = best_params ,
487- random_state_seed = args . seed ,
490+ random_state_seed = seed ,
488491 )
489- elif args . algorithm == "ax" :
492+ elif algorithm_name == "ax" :
490493 ax_client = AxClient (
491494 enforce_sequential_optimization = False ,
492- random_seed = args . seed ,
495+ random_seed = seed ,
493496 )
494497 AxClientMetric = namedtuple ("AxClientMetric" , "minimize" )
495498 ax_client .create_experiment (
@@ -498,25 +501,25 @@ def set_algorithm(experiment_name, config):
498501 objectives = {METRIC : AxClientMetric (minimize = True )},
499502 )
500503 algorithm = AxSearch (ax_client = ax_client , points_to_evaluate = best_params )
501- elif args . algorithm == "optuna" :
502- algorithm = OptunaSearch (points_to_evaluate = best_params , seed = args . seed )
503- elif args . algorithm == "pbt" :
504- print ("Warning: PBT does not support seed values. args. seed will be ignored." )
504+ elif algorithm_name == "optuna" :
505+ algorithm = OptunaSearch (points_to_evaluate = best_params , seed = seed )
506+ elif algorithm_name == "pbt" :
507+ print ("Warning: PBT does not support seed values. seed will be ignored." )
505508 algorithm = PopulationBasedTraining (
506509 time_attr = "training_iteration" ,
507- perturbation_interval = args . perturbation ,
510+ perturbation_interval = perturbation ,
508511 hyperparam_mutations = config ,
509512 synch = True ,
510513 )
511- elif args . algorithm == "random" :
514+ elif algorithm_name == "random" :
512515 algorithm = BasicVariantGenerator (
513- max_concurrent = args . jobs ,
514- random_state = args . seed ,
516+ max_concurrent = jobs ,
517+ random_state = seed ,
515518 )
516519
517520 # A wrapper algorithm for limiting the number of concurrent trials.
518- if args . algorithm not in ["random" , "pbt" ]:
519- algorithm = ConcurrencyLimiter (algorithm , max_concurrent = args . jobs )
521+ if algorithm_name not in ["random" , "pbt" ]:
522+ algorithm = ConcurrencyLimiter (algorithm , max_concurrent = jobs )
520523
521524 return algorithm
522525
@@ -607,7 +610,9 @@ def main():
607610
608611 if args .mode == "tune" :
609612 best_params = set_best_params (args .platform , args .design )
610- search_algo = set_algorithm (args .experiment , config_dict )
613+ search_algo = set_algorithm (args .algorithm , args .experiment ,
614+ best_params , args .seed , args .perturbation ,
615+ args .jobs , config_dict )
611616 TrainClass = set_training_class (args .eval )
612617 # PPAImprov requires a reference file to compute training scores.
613618 if args .eval == "ppa-improv" :
0 commit comments