Skip to content

Commit e11cfc5

Browse files
not inserting constexpr int loop unroll factor when 0
1 parent 90fdbda commit e11cfc5

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

kernel_tuner/util.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -606,9 +606,10 @@ def prepare_kernel_string(kernel_name, kernel_string, params, grid, threads, blo
606606
# pragma unroll loop_unroll_factor, loop_unroll_factor should be a constant integer expression
607607
# in OpenCL this isn't the case and we can just insert "#define loop_unroll_factor N"
608608
# using 0 to disable specifying a loop unrolling factor for this loop
609-
kernel_prefix += f"constexpr int {k} = {v};\n"
610609
if v == "0":
611610
kernel_string = re.sub(r"\n\s*#pragma\s+unroll\s+" + k, "\n", kernel_string) # + r"[^\S]*"
611+
else:
612+
kernel_prefix += f"constexpr int {k} = {v};\n"
612613
else:
613614
kernel_prefix += f"#define {k} {v}\n"
614615

test/test_util_functions.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,28 @@ def test_prepare_kernel_string():
198198
prepare_kernel_string("this", kernel, params, grid, threads, block_size_names, "", invalid_defines)
199199

200200

201+
def test_prepare_kernel_string_partial_loop_unrolling():
    """Check that a loop unroll factor of 0 is not emitted as a constant.

    The kernel is prepared twice with "CUDA" as the language argument:

    * with ``loop_unroll_factor_monkey`` set to 8, the output must contain
      the ``constexpr int loop_unroll_factor_monkey = 8;`` definition;
    * with ``loop_unroll_factor_monkey`` set to 0, the output must contain
      neither a ``constexpr int`` definition for it nor the
      ``#pragma unroll loop_unroll_factor_monkey`` line.
    """
    kernel = """this is a weird kernel(what * language, is this, anyway* C) {
#pragma unroll loop_unroll_factor_monkey
for monkey in the forest {
eat(banana);
}
}"""
    threads = (1, 2, 3)
    grid = (4, 5, 6)
    params = {"loop_unroll_factor_monkey": 8}

    # non-zero factor: the value must be defined as a constexpr int
    _, output = prepare_kernel_string("this", kernel, params, grid, threads, block_size_names, "CUDA", None)
    assert "constexpr int loop_unroll_factor_monkey = 8;" in output

    # factor 0: neither the constant nor the pragma may be left in the output
    params["loop_unroll_factor_monkey"] = 0
    _, output = prepare_kernel_string("this", kernel, params, grid, threads, block_size_names, "CUDA", None)
    assert "constexpr int loop_unroll_factor_monkey" not in output
    assert "#pragma unroll loop_unroll_factor_monkey" not in output
221+
222+
201223

202224
def test_replace_param_occurrences():
203225
kernel = "this is a weird kernel"

0 commit comments

Comments
 (0)