Skip to content

Commit 9c8613e

Browse files
add test
1 parent cd554d2 commit 9c8613e

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

kernel_tuner/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,8 +1117,9 @@ def compile_restrictions(
11171117
return noncompiled_restrictions + compiled_restrictions
11181118

11191119
def check_matching_problem_size(cached_problem_size, problem_size):
1120+
""" check if requested problem size matches the problem size in the cache"""
11201121
if not (np.array(cached_problem_size) == np.array(problem_size)).all():
1121-
raise ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_data['problem_size']}, requested: {kernel_options.problem_size}")
1122+
raise ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_problem_size}, requested: {problem_size}")
11221123

11231124
def process_cache(cache, kernel_options, tuning_options, runner):
11241125
"""Cache file for storing tuned configurations.

test/test_util_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,19 @@ def test_parse_restrictions():
726726
assert all(param in tune_params for param in params)
727727

728728

729+
def test_check_matching_problem_size():
730+
# these should error
731+
with pytest.raises(ValueError):
732+
check_matching_problem_size(42, 1000)
733+
with pytest.raises(ValueError):
734+
check_matching_problem_size([42,1], 42)
735+
# these should not error
736+
check_matching_problem_size(1000, (1000,))
737+
check_matching_problem_size([1000], 1000)
738+
check_matching_problem_size(1000, 1000)
739+
check_matching_problem_size(1000, [1000])
740+
741+
729742
def test_convert_constraint_lambdas():
730743

731744
restrictions = [lambda p: 32 <= p["block_size_x"]*p["block_size_y"] <= 1024,

0 commit comments

Comments
 (0)