Skip to content

Commit 20ea709

Browse files
committed
Implemented the pyatf_strategies, which enables using the pyATF strategies in Kernel Tuner
1 parent 268bf67 commit 20ea709

File tree

5 files changed

+317
-51
lines changed

5 files changed

+317
-51
lines changed

kernel_tuner/interface.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
except ImportError:
4848
torch = util.TorchPlaceHolder()
4949

50-
from kernel_tuner.strategies.wrapper import OptAlgWrapper
5150
from kernel_tuner.strategies import (
5251
basinhopping,
5352
bayes_opt,
@@ -62,9 +61,11 @@
6261
mls,
6362
ordered_greedy_mls,
6463
pso,
64+
pyatf_strategies,
6565
random_sample,
66-
simulated_annealing
66+
simulated_annealing,
6767
)
68+
from kernel_tuner.strategies.wrapper import OptAlgWrapper
6869

6970
strategy_map = {
7071
"brute_force": brute_force,
@@ -81,7 +82,8 @@
8182
"pso": pso,
8283
"simulated_annealing": simulated_annealing,
8384
"firefly_algorithm": firefly_algorithm,
84-
"bayes_opt": bayes_opt
85+
"bayes_opt": bayes_opt,
86+
"pyatf_strategies": pyatf_strategies,
8587
}
8688

8789

@@ -629,6 +631,7 @@ def tune_kernel(
629631
logging.debug("tuning_options: %s", util.get_config_string(tuning_options))
630632
logging.debug("device_options: %s", util.get_config_string(device_options))
631633

634+
strategy_string = strategy
632635
if strategy:
633636
if strategy in strategy_map:
634637
strategy = strategy_map[strategy]
@@ -861,10 +864,9 @@ def tune_kernel_T1(
861864
strategy: str=None,
862865
strategy_options: dict={},
863866
) -> tuple:
864-
"""
865-
Call the tune function with a T1 input file.
867+
"""Call the tune function with a T1 input file.
866868
867-
The device, strategy and strategy_options can be overridden by passing a strategy name and options, otherwise the input file specification is used.
869+
The device, strategy and strategy_options can be overridden by passing a strategy name and options, otherwise the input file specification is used.
868870
"""
869871
inputs = get_input_file(input_filepath)
870872
kernelspec: dict = inputs["KernelSpecification"]

kernel_tuner/searchspace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
framework_l = framework.lower()
7777
restrictions = restrictions if restrictions is not None else []
7878
self.tune_params = tune_params
79+
self.tune_params_pyatf = None
7980
self._tensorspace = None
8081
self.tensor_dtype = torch.float32 if torch_available else None
8182
self.tensor_device = torch.device("cpu") if torch_available else None
@@ -376,10 +377,13 @@ def get_params():
376377
constraint = res
377378
params.append(TP(key, vals, constraint, constraint_source))
378379
return params
380+
381+
# set data
382+
self.tune_params_pyatf = get_params()
379383

380384
# tune
381385
_, _, tuning_data = (
382-
Tuner().verbosity(0).tuning_parameters(*get_params()).search_technique(Exhaustive()).tune(costfunc)
386+
Tuner().verbosity(0).tuning_parameters(*self.tune_params_pyatf).search_technique(Exhaustive()).tune(costfunc)
383387
)
384388

385389
# transform the result into a list of parameter configurations for validation

kernel_tuner/strategies/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,6 @@ def scale_from_params(params, tune_params, eps):
356356

357357
def unscale_and_snap_to_nearest_valid(x, params, searchspace, eps):
358358
"""Helper func to snap to the nearest valid configuration"""
359-
360359
# params is nearest unscaled point, but is not valid
361360
neighbors = get_neighbors(params, searchspace)
362361

kernel_tuner/strategies/ptatf_strategies.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)