Skip to content

Commit 685aed3

Browse files
committed
added test_smem_args --> passed
1 parent f0ab5ef commit 685aed3

File tree

1 file changed

+39
-2
lines changed

1 file changed

+39
-2
lines changed

test/test_hip_functions.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
22
import ctypes
33
from .context import skip_if_no_pyhip
4+
from collections import OrderedDict
45

56
import pytest
67
import kernel_tuner
8+
from kernel_tuner import tune_kernel
79
from kernel_tuner.backends import hip as kt_hip
810
from kernel_tuner.core import KernelSource, KernelInstance
911

@@ -13,6 +15,29 @@
1315
except ImportError:
1416
pass
1517

18+
@pytest.fixture
19+
def env():
20+
kernel_string = """
21+
extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) {
22+
int i = blockIdx.x * block_size_x + threadIdx.x;
23+
if (i<n) {
24+
c[i] = a[i] + b[i];
25+
}
26+
}
27+
"""
28+
29+
size = 100
30+
a = np.random.randn(size).astype(np.float32)
31+
b = np.random.randn(size).astype(np.float32)
32+
c = np.zeros_like(b)
33+
n = np.int32(size)
34+
35+
args = [c, a, b, n]
36+
tune_params = OrderedDict()
37+
tune_params["block_size_x"] = [128 + 64 * i for i in range(15)]
38+
39+
return ["vector_add", kernel_string, size, args, tune_params]
40+
1641
@skip_if_no_pyhip
1742
def test_ready_argument_list():
1843

@@ -125,6 +150,18 @@ def test_copy_constant_memory_args():
125150

126151
assert(my_constant_data == output).all()
127152

128-
def dummy_func(a, b, block=0, grid=0, stream=None, shared=0, texrefs=None):
129-
pass
153+
@skip_if_no_pyhip
154+
def test_smem_args(env):
155+
result, _ = tune_kernel(*env,
156+
smem_args=dict(size="block_size_x*4"),
157+
verbose=True, lang="HIP")
158+
tune_params = env[-1]
159+
assert len(result) == len(tune_params["block_size_x"])
160+
result, _ = tune_kernel(
161+
*env,
162+
smem_args=dict(size=lambda p: p['block_size_x'] * 4),
163+
verbose=True, lang="HIP")
164+
tune_params = env[-1]
165+
assert len(result) == len(tune_params["block_size_x"])
166+
130167

0 commit comments

Comments
 (0)