Skip to content

Commit 1517565

Browse files
add support for user-defined optimization algorithms
1 parent 3f32fed commit 1517565

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python
2+
"""This is the minimal example from the README"""
3+
4+
import numpy
5+
import kernel_tuner
6+
from kernel_tuner import tune_kernel
7+
from kernel_tuner.file_utils import store_output_file, store_metadata_file
8+
9+
def tune():
10+
11+
kernel_string = """
12+
__global__ void vector_add(float *c, float *a, float *b, int n) {
13+
int i = blockIdx.x * block_size_x + threadIdx.x;
14+
if (i<n) {
15+
c[i] = a[i] + b[i];
16+
}
17+
}
18+
"""
19+
20+
size = 10000000
21+
22+
a = numpy.random.randn(size).astype(numpy.float32)
23+
b = numpy.random.randn(size).astype(numpy.float32)
24+
c = numpy.zeros_like(b)
25+
n = numpy.int32(size)
26+
27+
args = [c, a, b, n]
28+
29+
tune_params = dict()
30+
tune_params["block_size_x"] = [128+64*i for i in range(15)]
31+
32+
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, strategy=kernel_tuner.strategies.minimize, verbose=True)
33+
34+
# Store the tuning results in an output file
35+
store_output_file("vector_add.json", results, tune_params)
36+
37+
# Store the metadata of this run
38+
store_metadata_file("vector_add-metadata.json")
39+
40+
return results
41+
42+
43+
if __name__ == "__main__":
44+
tune()

kernel_tuner/interface.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -621,29 +621,15 @@ def tune_kernel(
621621
if strategy in strategy_map:
622622
strategy = strategy_map[strategy]
623623
else:
624-
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")
625-
626-
# make strategy_options into an Options object
627-
if tuning_options.strategy_options:
628-
if not isinstance(strategy_options, Options):
629-
tuning_options.strategy_options = Options(strategy_options)
630-
631-
# select strategy based on user options
632-
if "fraction" in tuning_options.strategy_options and not tuning_options.strategy == "random_sample":
633-
raise ValueError(
634-
'It is not possible to use fraction in combination with strategies other than "random_sample". '
635-
'Please set strategy="random_sample", when using "fraction" in strategy_options'
636-
)
637-
638-
# check if method is supported by the selected strategy
639-
if "method" in tuning_options.strategy_options:
640-
method = tuning_options.strategy_options.method
641-
if method not in strategy.supported_methods:
642-
raise ValueError("Method %s is not supported for strategy %s" % (method, tuning_options.strategy))
643-
644-
# if no strategy_options dict has been passed, create empty dictionary
645-
else:
646-
tuning_options.strategy_options = Options({})
624+
# check for user-defined strategy
625+
if hasattr(strategy, "tune") and callable(strategy.tune):
626+
# user-defined strategy
627+
pass
628+
else:
629+
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")
630+
631+
# ensure strategy_options is an Options object
632+
tuning_options.strategy_options = Options(strategy_options or {})
647633

648634
# if no strategy selected
649635
else:

0 commit comments

Comments
 (0)