|
36 | 36 |
|
37 | 37 | import kernel_tuner.core as core
|
38 | 38 | import kernel_tuner.util as util
|
39 |
| -from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results |
| 39 | +from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results, import_class_from_file |
40 | 40 | from kernel_tuner.integration import get_objective_defaults
|
41 | 41 | from kernel_tuner.runners.sequential import SequentialRunner
|
42 | 42 | from kernel_tuner.runners.simulation import SimulationRunner
|
|
47 | 47 | except ImportError:
|
48 | 48 | torch = util.TorchPlaceHolder()
|
49 | 49 |
|
| 50 | +from kernel_tuner.strategies.wrapper import OptAlgWrapper |
50 | 51 | from kernel_tuner.strategies import (
|
51 | 52 | basinhopping,
|
52 | 53 | bayes_opt,
|
|
62 | 63 | ordered_greedy_mls,
|
63 | 64 | pso,
|
64 | 65 | random_sample,
|
65 |
| - simulated_annealing, |
| 66 | + simulated_annealing |
66 | 67 | )
|
67 | 68 |
|
68 | 69 | strategy_map = {
|
@@ -894,6 +895,19 @@ def tune_kernel_T1(
|
894 | 895 | else:
|
895 | 896 | raise NotImplementedError(f"Budget type in {budget} is not supported")
|
896 | 897 |
|
| 898 | + # check if the strategy is a path |
| 899 | + if "custom_search_method_path" in strategy_options: |
| 900 | + # if it is a path, import the strategy from the file |
| 901 | + opt_path: Path = Path(strategy_options["custom_search_method_path"]) |
| 902 | + class_name: str = strategy |
| 903 | + assert opt_path.exists(), f"Custom search method path '{opt_path}' does not exist relative to current working directory {Path.cwd()}" |
| 904 | + optimizer_class = import_class_from_file(opt_path, class_name) |
| 905 | + budget = strategy_options.get("max_fevals", 1e12) # if not set, use a very large number to have it run out at the time limit |
| 906 | + filter_keys = ["custom_search_method_path", "max_fevals", "time_limit", "constraint_aware"] |
| 907 | + adjusted_strategy_options = {k:v for k, v in strategy_options.items() if k not in filter_keys} |
| 908 | + optimizer_instance = optimizer_class(budget=budget, **adjusted_strategy_options) |
| 909 | + strategy = OptAlgWrapper(optimizer_instance) |
| 910 | + |
897 | 911 | # set the cache path
|
898 | 912 | if cache_filepath is None and "SimulationInput" in kernelspec:
|
899 | 913 | cache_filepath = Path(kernelspec["SimulationInput"])
|
|
0 commit comments