Skip to content

Commit 4f245c6

Browse files
also store problem size in cache
1 parent bf8d4f9 commit 4f245c6

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

kernel_tuner/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
728728
729729
{ device_name: "name of device"
730730
kernel_name: "name of kernel"
731+
problem_size: (int, int, int)
731732
tune_params_keys: list
732733
tune_params:
733734
cache: {
@@ -754,6 +755,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
754755
c = OrderedDict()
755756
c["device_name"] = runner.dev.name
756757
c["kernel_name"] = kernel_options.kernel_name
758+
c["problem_size"] = kernel_options.problem_size if not callable(kernel_options.problem_size) else "callable"
757759
c["tune_params_keys"] = list(tuning_options.tune_params.keys())
758760
c["tune_params"] = tuning_options.tune_params
759761
c["cache"] = {}
@@ -780,6 +782,11 @@ def process_cache(cache, kernel_options, tuning_options, runner):
780782
raise ValueError("Cannot load cache which contains results for different device")
781783
if cached_data["kernel_name"] != kernel_options.kernel_name:
782784
raise ValueError("Cannot load cache which contains results for different kernel")
785+
if "problem_size" in cached_data and not callable(kernel_options.problem_size):
786+
# cache returns list, problem_size is likely a tuple. Therefore, the next check
787+
# checks the equality of all items in the list/tuples individually
788+
if not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]):
789+
raise ValueError("Cannot load cache which contains results for different problem_size")
783790
if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()):
784791
raise ValueError("Cannot load cache which contains results obtained with different tunable parameters")
785792

test/test_util_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def assert_open_cachefile_is_correctly_parsed(cache):
549549
cache = get_temp_filename(suffix=".json")
550550
delete_temp_file(cache)
551551

552-
kernel_options = Options(kernel_name="test_kernel")
552+
kernel_options = Options(kernel_name="test_kernel", problem_size=(1, 2))
553553
tuning_options = Options(cache=cache, tune_params=Options(x=[1, 2, 3, 4]), simulation_mode=False)
554554
runner = Options(dev=Options(name="test_device"), simulation_mode=False)
555555

0 commit comments

Comments
 (0)