Skip to content

Commit b90f5e4

Browse files
Merge pull request #306 from KernelTuner/fix-issue-264
Fix error 'grid divisor cannot be integer' (issue #264)
2 parents 74a4e34 + 2437bfc commit b90f5e4

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

kernel_tuner/util.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -442,16 +442,24 @@ def compact_number(v):
442442
def get_grid_dimensions(current_problem_size, params, grid_div, block_size_names):
443443
"""Compute grid dims based on problem sizes and listed grid divisors."""
444444

445-
def get_dimension_divisor(divisor_list, default, params):
446-
if divisor_list is None:
447-
if default in params:
448-
divisor_list = [default]
449-
else:
450-
return 1
451-
if callable(divisor_list):
452-
return divisor_list(params)
445+
def get_dimension_divisor(divisor, default, params):
446+
divisor_num = 1
447+
448+
if divisor is None:
449+
divisor_num = params.get(default, 1)
450+
elif isinstance(divisor, int):
451+
divisor_num = divisor
452+
elif callable(divisor):
453+
divisor_num = divisor(params)
454+
elif isinstance(divisor, str):
455+
divisor_num = int(eval(replace_param_occurrences(divisor, params)))
456+
elif np.iterable(divisor):
457+
for div in divisor:
458+
divisor_num *= get_dimension_divisor(div, 1, params)
453459
else:
454-
return np.prod([int(eval(replace_param_occurrences(s, params))) for s in divisor_list])
460+
raise ValueError("Error: unrecognized type in grid divisor list, should be any of int, str, callable, or iterable")
461+
462+
return divisor_num
455463

456464
divisors = [get_dimension_divisor(d, block_size_names[i], params) for i, d in enumerate(grid_div)]
457465
return tuple(int(np.ceil(float(current_problem_size[i]) / float(d))) for i, d in enumerate(divisors))

test/test_util_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ def test_get_grid_dimensions1():
5959
assert grid[1] == 25
6060
assert grid[2] == 1
6161

62+
grid = get_grid_dimensions(
63+
problem_size, params, ("41", 37, None), block_size_names
64+
)
65+
66+
assert grid[0] == 25
67+
assert grid[1] == 28
68+
assert grid[2] == 1
69+
70+
grid = get_grid_dimensions(
71+
problem_size, params, (None, [2, "block_y"], None), block_size_names
72+
)
73+
74+
assert grid[0] == 1024
75+
assert grid[1] == 14
76+
assert grid[2] == 1
77+
6278

6379
def test_get_grid_dimensions2():
6480
problem_size = (1024, 1024, 1)

0 commit comments

Comments
 (0)