Skip to content

Commit afe70c9

Browse files
fix support for 1D problem_size in caches
1 parent 74a4e34 commit afe70c9

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

kernel_tuner/util.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,9 @@ def get_kernel_string(kernel_source, params=None):
520520
kernel_string = read_file(kernel_source)
521521
elif isinstance(kernel_source, str):
522522
if looks_like_a_filename(kernel_source):
523-
kernel_string = read_file(kernel_source) or kernel_source
523+
kernel_string = read_file(kernel_source)
524+
if kernel_string is None:
525+
raise ValueError(f"{kernel_source} looks like a filename, but cannot open file")
524526
else:
525527
kernel_string = kernel_source
526528
else:
@@ -1115,6 +1117,9 @@ def compile_restrictions(
11151117
noncompiled_restrictions.append((r, [], r))
11161118
return noncompiled_restrictions + compiled_restrictions
11171119

1120+
def check_matching_problem_size(cached_problem_size, problem_size):
1121+
if not all(np.array(cached_problem_size) == np.array(problem_size)):
1122+
ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_data['problem_size']}, requested: {kernel_options.problem_size}")
11181123

11191124
def process_cache(cache, kernel_options, tuning_options, runner):
11201125
"""Cache file for storing tuned configurations.
@@ -1185,18 +1190,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
11851190
f"Cannot load cache which contains results for different kernel (cache: {cached_data['kernel_name']}, actual: {kernel_options.kernel_name})"
11861191
)
11871192
if "problem_size" in cached_data and not callable(kernel_options.problem_size):
1188-
# if it's a single value, convert to an array
1189-
if isinstance(cached_data["problem_size"], int):
1190-
cached_data["problem_size"] = [cached_data["problem_size"]]
1191-
# if problem_size is not iterable, compare directly
1192-
if not hasattr(kernel_options.problem_size, "__iter__"):
1193-
if cached_data["problem_size"] != kernel_options.problem_size:
1194-
raise ValueError("Cannot load cache which contains results for different problem_size")
1195-
# else (problem_size is iterable)
1196-
# cache returns list, problem_size is likely a tuple. Therefore, the next check
1197-
# checks the equality of all items in the list/tuples individually
1198-
elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]):
1199-
raise ValueError("Cannot load cache which contains results for different problem_size")
1193+
check_matching_problem_size(cached_data["problem_size"], kernel_options.problem_size)
12001194
if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()):
12011195
if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]):
12021196
raise ValueError(

test/test_util_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,10 @@ def gen_kernel(params):
525525

526526
def test_get_kernel_string_filename_not_found():
527527
# when the string looks like a filename, but the file does not exist
528-
# assume the string is not a filename after all
528+
# check if throws an exception
529529
bogus_filename = "filename_3456789.cu"
530-
answer = get_kernel_string(bogus_filename)
531-
assert answer == bogus_filename
530+
with pytest.raises(ValueError):
531+
get_kernel_string(bogus_filename)
532532

533533

534534
def test_looks_like_a_filename1():

0 commit comments

Comments
 (0)