Skip to content

Commit 24177bd

Browse files
committed
add fixes for sweep
Signed-off-by: Jack Luar <[email protected]>
1 parent c4ce775 commit 24177bd

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

tools/AutoTuner/src/autotuner/distributed.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -422,17 +422,17 @@ def parse_arguments():
422422

423423
args = parser.parse_args()
424424
if args.mode == "tune":
425-
args.algorithm = args.algorithm.lower()
425+
args.tune.algorithm = args.tune.algorithm.lower()
426426
# Validation of arguments
427-
if args.eval == "ppa-improv" and args.reference is None:
427+
if args.tune.eval == "ppa-improv" and args.tune.reference is None:
428428
print(
429429
'[ERROR TUN-0006] The argument "--eval ppa-improv"'
430430
' requires that "--reference <FILE>" is also given.'
431431
)
432432
sys.exit(7)
433433

434434
# Check for experiment name and resume flag.
435-
if args.resume and args.experiment == "test":
435+
if args.tune.resume and args.experiment == "test":
436436
print(
437437
'[ERROR TUN-0031] The flag "--resume"'
438438
' requires that "--experiment NAME" is also given.'
@@ -587,49 +587,57 @@ def main():
587587

588588
# Read config and original files before handling where to run in case we
589589
# need to upload the files.
590-
config_dict, SDC_ORIGINAL, FR_ORIGINAL = read_config(
591-
os.path.abspath(args.config), args.mode, getattr(args, "algorithm", None)
592-
)
590+
if args.mode == "tune":
591+
config_dict, SDC_ORIGINAL, FR_ORIGINAL = read_config(
592+
file_name=os.path.abspath(args.tune.config),
593+
mode=args.mode,
594+
algorithm=args.tune.algorithm,
595+
)
596+
else:
597+
config_dict, SDC_ORIGINAL, FR_ORIGINAL = read_config(
598+
file_name=os.path.abspath(args.sweep.config),
599+
mode=args.mode,
600+
)
593601

594602
LOCAL_DIR, ORFS_FLOW_DIR, INSTALL_PATH = prepare_ray_server(args)
595603

596604
if args.mode == "tune":
597605
best_params = set_best_params(args.platform, args.design)
598606
search_algo = set_algorithm(
599-
args.algorithm,
607+
args.tune.algorithm,
600608
args.experiment,
601609
best_params,
602-
args.seed,
603-
args.perturbation,
610+
args.tune.seed,
611+
args.tune.perturbation,
604612
args.jobs,
605613
config_dict,
606614
)
607-
TrainClass = set_training_class(args.eval)
615+
TrainClass = set_training_class(args.tune.eval)
608616
# PPAImprov requires a reference file to compute training scores.
609-
if args.eval == "ppa-improv":
610-
reference = read_metrics(args.reference)
617+
if args.tune.eval == "ppa-improv":
618+
reference = read_metrics(args.tune.reference)
611619

612620
tune_args = dict(
613621
name=args.experiment,
614622
metric=METRIC,
615623
mode="min",
616-
num_samples=args.samples,
624+
num_samples=args.tune.samples,
617625
fail_fast=False,
618626
storage_path=LOCAL_DIR,
619-
resume=args.resume,
620-
stop={"training_iteration": args.iterations},
627+
resume=args.tune.resume,
628+
stop={"training_iteration": args.tune.iterations},
621629
resources_per_trial={"cpu": os.cpu_count() / args.jobs},
622630
log_to_file=["trail-out.log", "trail-err.log"],
623631
trial_name_creator=lambda x: f"variant-{x.trainable_name}-{x.trial_id}-ray",
624632
trial_dirname_creator=lambda x: f"variant-{x.trainable_name}-{x.trial_id}-ray",
625633
)
626-
if args.algorithm == "pbt":
634+
if args.tune.algorithm == "pbt":
627635
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = str(args.jobs)
628636
tune_args["scheduler"] = search_algo
629637
else:
630638
tune_args["search_alg"] = search_algo
631639
tune_args["scheduler"] = AsyncHyperBandScheduler()
632-
if args.algorithm != "ax":
640+
if args.tune.algorithm != "ax":
633641
tune_args["config"] = config_dict
634642
analysis = tune.run(TrainClass, **tune_args)
635643

tools/AutoTuner/src/autotuner/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def read_metrics(file_name):
410410
return ret
411411

412412

413-
def read_config(file_name, mode, algorithm):
413+
def read_config(file_name, mode, algorithm = None):
414414
"""
415415
Please consider inclusive, exclusive
416416
Most type uses [min, max)

0 commit comments

Comments
 (0)