Skip to content

Commit 8f3744d

Browse files
committed
Implemented passing custom search method path and options usng T1 format
1 parent b8a9902 commit 8f3744d

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

kernel_tuner/interface.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import kernel_tuner.core as core
3838
import kernel_tuner.util as util
39-
from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results
39+
from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results, import_class_from_file
4040
from kernel_tuner.integration import get_objective_defaults
4141
from kernel_tuner.runners.sequential import SequentialRunner
4242
from kernel_tuner.runners.simulation import SimulationRunner
@@ -47,6 +47,7 @@
4747
except ImportError:
4848
torch = util.TorchPlaceHolder()
4949

50+
from kernel_tuner.strategies.wrapper import OptAlgWrapper
5051
from kernel_tuner.strategies import (
5152
basinhopping,
5253
bayes_opt,
@@ -62,7 +63,7 @@
6263
ordered_greedy_mls,
6364
pso,
6465
random_sample,
65-
simulated_annealing,
66+
simulated_annealing
6667
)
6768

6869
strategy_map = {
@@ -894,6 +895,19 @@ def tune_kernel_T1(
894895
else:
895896
raise NotImplementedError(f"Budget type in {budget} is not supported")
896897

898+
# check if the strategy is a path
899+
if "custom_search_method_path" in strategy_options:
900+
# if it is a path, import the strategy from the file
901+
opt_path: Path = Path(strategy_options["custom_search_method_path"])
902+
class_name: str = strategy
903+
assert opt_path.exists(), f"Custom search method path '{opt_path}' does not exist relative to current working directory {Path.cwd()}"
904+
optimizer_class = import_class_from_file(opt_path, class_name)
905+
budget = strategy_options.get("max_fevals", 1e12) # if not set, use a very large number to have it run out at the time limit
906+
filter_keys = ["custom_search_method_path", "max_fevals", "time_limit", "constraint_aware"]
907+
adjusted_strategy_options = {k:v for k, v in strategy_options.items() if k not in filter_keys}
908+
optimizer_instance = optimizer_class(budget=budget, **adjusted_strategy_options)
909+
strategy = OptAlgWrapper(optimizer_instance)
910+
897911
# set the cache path
898912
if cache_filepath is None and "SimulationInput" in kernelspec:
899913
cache_filepath = Path(kernelspec["SimulationInput"])

kernel_tuner/strategies/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def tune(self, searchspace: Searchspace, runner, tuning_options):
2121
cost_func.get_bounds_x0_eps()
2222

2323
try:
24-
self.optimizer(cost_func)
24+
self.optimizer(cost_func, searchspace)
2525
except util.StopCriterionReached as e:
2626
if tuning_options.verbose:
2727
print(e)

0 commit comments

Comments
 (0)