Skip to content

Commit 1d916bd

Browse files
committed
Implemented passing whether or not to use the searchspace cache as a hyperparameter
1 parent 934be28 commit 1d916bd

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

kernel_tuner/strategies/pyatf_strategies.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111

1212
supported_searchtechniques = ["auc_bandit", "differential_evolution", "pattern_search", "round_robin", "simulated_annealing", "torczon"]
1313

14-
_options = dict(searchtechnique=(f"PyATF optimization algorithm to use, choose any from {supported_searchtechniques}", "simulated_annealing"))
14+
_options = dict(
15+
searchtechnique=(f"PyATF optimization algorithm to use, choose any from {supported_searchtechniques}", "simulated_annealing"),
16+
use_searchspace_cache=(f"Use a cached search space if available, otherwise create a new one.", False)
17+
)
1518

1619
def get_cache_checksum(d: dict):
1720
checksum=0
@@ -26,9 +29,13 @@ def tune(searchspace: Searchspace, runner, tuning_options):
2629
from pyatf.search_techniques.search_technique import SearchTechnique
2730
from pyatf.search_space import SearchSpace as pyATFSearchSpace
2831
from pyatf import TP
32+
33+
# get the search technique module name and whether to use search space caching
34+
module_name, use_searchspace_cache = common.get_options(tuning_options.strategy_options, _options)
2935
try:
30-
import dill
31-
pyatf_search_space_caching = True
36+
if use_searchspace_cache:
37+
import dill
38+
pyatf_search_space_caching = use_searchspace_cache
3239
except ImportError:
3340
from warnings import warn
3441
pyatf_search_space_caching = False
@@ -38,7 +45,6 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3845
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=False, snap=False, return_invalid=False)
3946

4047
# dynamically import the search technique based on the provided options
41-
module_name, = common.get_options(tuning_options.strategy_options, _options)
4248
module = import_module(f"pyatf.search_techniques.{module_name}")
4349
class_name = [d for d in dir(module) if d.lower() == module_name.replace('_','')][0]
4450
searchtechnique_class = getattr(module, class_name)

0 commit comments

Comments
 (0)