@@ -422,17 +422,17 @@ def parse_arguments():
422422
423423 args = parser .parse_args ()
424424 if args .mode == "tune" :
425- args .algorithm = args .algorithm .lower ()
425+ args .tune . algorithm = args . tune .algorithm .lower ()
426426 # Validation of arguments
427- if args .eval == "ppa-improv" and args .reference is None :
427+ if args .tune . eval == "ppa-improv" and args . tune .reference is None :
428428 print (
429429 '[ERROR TUN-0006] The argument "--eval ppa-improv"'
430430 ' requires that "--reference <FILE>" is also given.'
431431 )
432432 sys .exit (7 )
433433
434434 # Check for experiment name and resume flag.
435- if args .resume and args .experiment == "test" :
435+ if args .tune . resume and args .experiment == "test" :
436436 print (
437437 '[ERROR TUN-0031] The flag "--resume"'
438438 ' requires that "--experiment NAME" is also given.'
@@ -587,49 +587,57 @@ def main():
587587
588588 # Read config and original files before handling where to run in case we
589589 # need to upload the files.
590- config_dict , SDC_ORIGINAL , FR_ORIGINAL = read_config (
591- os .path .abspath (args .config ), args .mode , getattr (args , "algorithm" , None )
592- )
590+ if args .mode == "tune" :
591+ config_dict , SDC_ORIGINAL , FR_ORIGINAL = read_config (
592+ file_name = os .path .abspath (args .tune .config ),
593+ mode = args .mode ,
594+ algorithm = args .tune .algorithm ,
595+ )
596+ else :
597+ config_dict , SDC_ORIGINAL , FR_ORIGINAL = read_config (
598+ file_name = os .path .abspath (args .sweep .config ),
599+ mode = args .mode ,
600+ )
593601
594602 LOCAL_DIR , ORFS_FLOW_DIR , INSTALL_PATH = prepare_ray_server (args )
595603
596604 if args .mode == "tune" :
597605 best_params = set_best_params (args .platform , args .design )
598606 search_algo = set_algorithm (
599- args .algorithm ,
607+ args .tune . algorithm ,
600608 args .experiment ,
601609 best_params ,
602- args .seed ,
603- args .perturbation ,
610+ args .tune . seed ,
611+ args .tune . perturbation ,
604612 args .jobs ,
605613 config_dict ,
606614 )
607- TrainClass = set_training_class (args .eval )
615+ TrainClass = set_training_class (args .tune . eval )
608616 # PPAImprov requires a reference file to compute training scores.
609- if args .eval == "ppa-improv" :
610- reference = read_metrics (args .reference )
617+ if args .tune . eval == "ppa-improv" :
618+ reference = read_metrics (args .tune . reference )
611619
612620 tune_args = dict (
613621 name = args .experiment ,
614622 metric = METRIC ,
615623 mode = "min" ,
616- num_samples = args .samples ,
624+ num_samples = args .tune . samples ,
617625 fail_fast = False ,
618626 storage_path = LOCAL_DIR ,
619- resume = args .resume ,
620- stop = {"training_iteration" : args .iterations },
627+ resume = args .tune . resume ,
628+ stop = {"training_iteration" : args .tune . iterations },
621629 resources_per_trial = {"cpu" : os .cpu_count () / args .jobs },
622630 log_to_file = ["trail-out.log" , "trail-err.log" ],
623631 trial_name_creator = lambda x : f"variant-{ x .trainable_name } -{ x .trial_id } -ray" ,
624632 trial_dirname_creator = lambda x : f"variant-{ x .trainable_name } -{ x .trial_id } -ray" ,
625633 )
626- if args .algorithm == "pbt" :
634+ if args .tune . algorithm == "pbt" :
627635 os .environ ["TUNE_MAX_PENDING_TRIALS_PG" ] = str (args .jobs )
628636 tune_args ["scheduler" ] = search_algo
629637 else :
630638 tune_args ["search_alg" ] = search_algo
631639 tune_args ["scheduler" ] = AsyncHyperBandScheduler ()
632- if args .algorithm != "ax" :
640+ if args .tune . algorithm != "ax" :
633641 tune_args ["config" ] = config_dict
634642 analysis = tune .run (TrainClass , ** tune_args )
635643
0 commit comments