|
1 | 1 | import numpy as np |
2 | 2 | import ctypes |
3 | 3 | from .context import skip_if_no_pyhip |
| 4 | +from collections import OrderedDict |
4 | 5 |
|
5 | 6 | import pytest |
6 | 7 | import kernel_tuner |
| 8 | +from kernel_tuner import tune_kernel |
7 | 9 | from kernel_tuner.backends import hip as kt_hip |
8 | 10 | from kernel_tuner.core import KernelSource, KernelInstance |
9 | 11 |
|
|
13 | 15 | except ImportError: |
14 | 16 | pass |
15 | 17 |
|
| 18 | +@pytest.fixture |
| 19 | +def env(): |
| 20 | + kernel_string = """ |
| 21 | + extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) { |
| 22 | + int i = blockIdx.x * block_size_x + threadIdx.x; |
| 23 | + if (i<n) { |
| 24 | + c[i] = a[i] + b[i]; |
| 25 | + } |
| 26 | + } |
| 27 | + """ |
| 28 | + |
| 29 | + size = 100 |
| 30 | + a = np.random.randn(size).astype(np.float32) |
| 31 | + b = np.random.randn(size).astype(np.float32) |
| 32 | + c = np.zeros_like(b) |
| 33 | + n = np.int32(size) |
| 34 | + |
| 35 | + args = [c, a, b, n] |
| 36 | + tune_params = OrderedDict() |
| 37 | + tune_params["block_size_x"] = [128 + 64 * i for i in range(15)] |
| 38 | + |
| 39 | + return ["vector_add", kernel_string, size, args, tune_params] |
| 40 | + |
16 | 41 | @skip_if_no_pyhip |
17 | 42 | def test_ready_argument_list(): |
18 | 43 |
|
@@ -125,6 +150,18 @@ def test_copy_constant_memory_args(): |
125 | 150 |
|
126 | 151 | assert(my_constant_data == output).all() |
127 | 152 |
|
128 | | -def dummy_func(a, b, block=0, grid=0, stream=None, shared=0, texrefs=None): |
129 | | - pass |
| 153 | +@skip_if_no_pyhip |
| 154 | +def test_smem_args(env): |
| 155 | + result, _ = tune_kernel(*env, |
| 156 | + smem_args=dict(size="block_size_x*4"), |
| 157 | + verbose=True, lang="HIP") |
| 158 | + tune_params = env[-1] |
| 159 | + assert len(result) == len(tune_params["block_size_x"]) |
| 160 | + result, _ = tune_kernel( |
| 161 | + *env, |
| 162 | + smem_args=dict(size=lambda p: p['block_size_x'] * 4), |
| 163 | + verbose=True, lang="HIP") |
| 164 | + tune_params = env[-1] |
| 165 | + assert len(result) == len(tune_params["block_size_x"]) |
| 166 | + |
130 | 167 |
|
0 commit comments