Skip to content

Commit a89af66

Browse files
committed
Implemented passing keyword arguments like meta strategy and time limit to hyperparameter tuning CLI
1 parent b7e779e commit a89af66

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

kernel_tuner/hyper.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from random import randint
66
from argparse import ArgumentParser
77

8+
import numpy as np
9+
810
import kernel_tuner
911

1012

@@ -94,11 +96,26 @@ def put_if_not_present(target_dict, key, value):
9496
return list(result_unique.values()), env
9597

9698
if __name__ == "__main__":
99+
"""Main function to run the hyperparameter tuning. Run with `python hyper.py strategy_to_tune=`."""
100+
97101
parser = ArgumentParser()
98-
parser.add_argument("strategy_to_tune")
102+
parser.add_argument("strategy_to_tune", type=str, help="The strategy to tune hyperparameters for.")
103+
parser.add_argument("--meta_strategy", nargs='?', default="genetic_algorithm", type=str, help="The meta-strategy to use for hyperparameter tuning.")
104+
parser.add_argument("--max_time", nargs='?', default=60*60*24, type=int, help="The maximum time in seconds for the hyperparameter tuning.")
99105
args = parser.parse_args()
100106
strategy_to_tune = args.strategy_to_tune
101107

108+
kwargs = dict(
109+
verbose=True,
110+
quiet=False,
111+
simulation_mode=False,
112+
strategy=args.meta_strategy,
113+
cache=f"hyperparamtuning_t={strategy_to_tune}_m={args.meta_strategy}.json",
114+
strategy_options=dict(
115+
time_limit=args.max_time,
116+
)
117+
)
118+
102119
# select the hyperparameter parameters for the selected optimization algorithm
103120
restrictions = []
104121
if strategy_to_tune.lower() == "pso":
@@ -131,9 +148,10 @@ def put_if_not_present(target_dict, key, value):
131148
elif strategy_to_tune.lower() == "diff_evo":
132149
hyperparams = {
133150
'method': ["best1bin", "rand1bin", "best2bin", "rand2bin", "best1exp", "rand1exp", "best2exp", "rand2exp", "currenttobest1bin", "currenttobest1exp", "randtobest1bin", "randtobest1exp"], # best1bin
134-
'popsize': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100], # 50
135-
'F': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0], # 1.3
136-
'CR': [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0] # 0.9
151+
'popsize': list(range(1, 100+1, 1)), # 50
152+
'popsize_times_dimensions': [True, False], # False
153+
'F': list(np.arange(0.05, 2.0+0.05, 0.05)), # 1.3
154+
'CR': list(np.arange(0.05, 1.0+0.05, 0.05)) # 0.9
137155
}
138156
elif strategy_to_tune.lower() == "basinhopping":
139157
hyperparams = {
@@ -172,6 +190,6 @@ def put_if_not_present(target_dict, key, value):
172190
raise ValueError(f"Invalid argument {strategy_to_tune=}")
173191

174192
# run the hyperparameter tuning
175-
result, env = tune_hyper_params(strategy_to_tune.lower(), hyperparams, restrictions=restrictions)
193+
result, env = tune_hyper_params(strategy_to_tune.lower(), hyperparams, restrictions=restrictions, **kwargs)
176194
print(result)
177195
print(env['best_config'])

0 commit comments

Comments
 (0)