Skip to content

Commit c65cec3

Browse files
Merge pull request #321 from KernelTuner/improve_warning_on_kernel_source_not_found
Improve warning on kernel source not found
2 parents b78d909 + 0353c40 commit c65cec3

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

kernel_tuner/util.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ def get_kernel_string(kernel_source, params=None):
528528
kernel_string = read_file(kernel_source)
529529
elif isinstance(kernel_source, str):
530530
if looks_like_a_filename(kernel_source):
531-
kernel_string = read_file(kernel_source) or kernel_source
531+
with open(kernel_source, "r") as f:
532+
kernel_string = f.read()
532533
else:
533534
kernel_string = kernel_source
534535
else:
@@ -1123,6 +1124,10 @@ def compile_restrictions(
11231124
noncompiled_restrictions.append((r, [], r))
11241125
return noncompiled_restrictions + compiled_restrictions
11251126

1127+
def check_matching_problem_size(cached_problem_size, problem_size):
1128+
"""Check the if requested problem size matches the problem size in the cache."""
1129+
if not (np.array(cached_problem_size) == np.array(problem_size)).all():
1130+
raise ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_problem_size}, requested: {problem_size}")
11261131

11271132
def process_cache(cache, kernel_options, tuning_options, runner):
11281133
"""Cache file for storing tuned configurations.
@@ -1193,18 +1198,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
11931198
f"Cannot load cache which contains results for different kernel (cache: {cached_data['kernel_name']}, actual: {kernel_options.kernel_name})"
11941199
)
11951200
if "problem_size" in cached_data and not callable(kernel_options.problem_size):
1196-
# if it's a single value, convert to an array
1197-
if isinstance(cached_data["problem_size"], int):
1198-
cached_data["problem_size"] = [cached_data["problem_size"]]
1199-
# if problem_size is not iterable, compare directly
1200-
if not hasattr(kernel_options.problem_size, "__iter__"):
1201-
if cached_data["problem_size"] != kernel_options.problem_size:
1202-
raise ValueError("Cannot load cache which contains results for different problem_size")
1203-
# else (problem_size is iterable)
1204-
# cache returns list, problem_size is likely a tuple. Therefore, the next check
1205-
# checks the equality of all items in the list/tuples individually
1206-
elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]):
1207-
raise ValueError("Cannot load cache which contains results for different problem_size")
1201+
check_matching_problem_size(cached_data["problem_size"], kernel_options.problem_size)
12081202
if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()):
12091203
if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]):
12101204
raise ValueError(

test/test_util_functions.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,10 @@ def gen_kernel(params):
541541

542542
def test_get_kernel_string_filename_not_found():
543543
# when the string looks like a filename, but the file does not exist
544-
# assume the string is not a filename after all
544+
# check if throws an exception
545545
bogus_filename = "filename_3456789.cu"
546-
answer = get_kernel_string(bogus_filename)
547-
assert answer == bogus_filename
546+
with pytest.raises(FileNotFoundError):
547+
get_kernel_string(bogus_filename)
548548

549549

550550
def test_looks_like_a_filename1():
@@ -742,6 +742,24 @@ def test_parse_restrictions():
742742
assert all(param in tune_params for param in params)
743743

744744

745+
def test_check_matching_problem_size():
746+
# these should error
747+
with pytest.raises(ValueError):
748+
check_matching_problem_size(42, 1000)
749+
with pytest.raises(ValueError):
750+
check_matching_problem_size([42,1], 42)
751+
with pytest.raises(ValueError):
752+
check_matching_problem_size([42,0], 42)
753+
with pytest.raises(ValueError):
754+
check_matching_problem_size(None, 42)
755+
# these should not error
756+
check_matching_problem_size(1000, (1000,))
757+
check_matching_problem_size([1000], 1000)
758+
check_matching_problem_size(1000, 1000)
759+
check_matching_problem_size(1000, [1000])
760+
check_matching_problem_size([1000,], 1000)
761+
762+
745763
def test_convert_constraint_lambdas():
746764

747765
restrictions = [lambda p: 32 <= p["block_size_x"]*p["block_size_y"] <= 1024,

0 commit comments

Comments
 (0)