Skip to content

Commit 3b0e3b9

Browse files
committed
Add support for passing cupy arrays to "C" lang
1 parent 7a01408 commit 3b0e3b9

File tree

2 files changed

+195
-10
lines changed

2 files changed

+195
-10
lines changed

kernel_tuner/backends/compiler.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@
2121
SkippableFailure,
2222
)
2323

24+
try:
25+
import cupy as cp
26+
except ImportError:
27+
cp = None
28+
29+
30+
def is_cupy_array(array):
31+
"""Check if something is a cupy array."""
32+
return cp is not None and isinstance(array, cp.ndarray)
33+
34+
35+
def get_array_module(*args):
36+
"""Return the array module for arguments."""
37+
return np if cp is None else cp.get_array_module(*args)
38+
39+
2440
dtype_map = {
2541
"int8": C.c_int8,
2642
"int16": C.c_int16,
@@ -112,9 +128,9 @@ def ready_argument_list(self, arguments):
112128
ctype_args = [None for _ in arguments]
113129

114130
for i, arg in enumerate(arguments):
115-
if not isinstance(arg, (np.ndarray, np.number)):
131+
if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)):
116132
raise TypeError(
117-
"Argument is not numpy ndarray or numpy scalar %s" % type(arg)
133+
f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}"
118134
)
119135
dtype_str = str(arg.dtype)
120136
if isinstance(arg, np.ndarray):
@@ -129,6 +145,8 @@ def ready_argument_list(self, arguments):
129145
raise TypeError("unknown dtype for ndarray")
130146
elif isinstance(arg, np.generic):
131147
data_ctypes = dtype_map[dtype_str](arg)
148+
elif is_cupy_array(arg):
149+
data_ctypes = C.c_void_p(arg.data.ptr)
132150
ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes)
133151
return ctype_args
134152

@@ -326,7 +344,10 @@ def memset(self, allocation, value, size):
326344
:param size: The size of to the allocation unit in bytes
327345
:type size: int
328346
"""
329-
C.memset(allocation.ctypes, value, size)
347+
if is_cupy_array(allocation.numpy):
348+
cp.cuda.runtime.memset(allocation.numpy.data.ptr, value, size)
349+
else:
350+
C.memset(allocation.ctypes, value, size)
330351

331352
def memcpy_dtoh(self, dest, src):
332353
"""a simple memcpy copying from an Argument to a numpy array
@@ -337,18 +358,30 @@ def memcpy_dtoh(self, dest, src):
337358
:param src: An Argument for some memory allocation
338359
:type src: Argument
339360
"""
340-
dest[:] = src.numpy
361+
if isinstance(dest, np.ndarray) and isinstance(src.numpy, cp.ndarray):
362+
# Implicit conversion to a NumPy array is not allowed.
363+
value = src.numpy.get()
364+
else:
365+
value = src.numpy
366+
xp = get_array_module(dest)
367+
dest[:] = xp.asarray(value)
341368

342369
def memcpy_htod(self, dest, src):
343370
"""a simple memcpy copying from a numpy array to an Argument
344371
345372
:param dest: An Argument for some memory allocation
346-
:type dst: Argument
373+
:type dest: Argument
347374
348375
:param src: A numpy array containing the source data
349376
:type src: np.ndarray
350377
"""
351-
dest.numpy[:] = src
378+
if isinstance(dest.numpy, np.ndarray) and isinstance(src, cp.ndarray):
379+
# Implicit conversion to a NumPy array is not allowed.
380+
value = src.get()
381+
else:
382+
value = src
383+
xp = get_array_module(dest.numpy)
384+
dest.numpy[:] = xp.asarray(value)
352385

353386
def cleanup_lib(self):
354387
"""unload the previously loaded shared library"""

test/test_compiler_functions.py

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from unittest.mock import patch, Mock
1212

1313
import kernel_tuner
14-
from kernel_tuner.backends.compiler import CompilerFunctions, Argument
14+
from kernel_tuner.backends.compiler import CompilerFunctions, Argument, is_cupy_array, get_array_module
1515
from kernel_tuner.core import KernelSource, KernelInstance
1616
from kernel_tuner import util
1717

18-
from .context import skip_if_no_gfortran, skip_if_no_gcc, skip_if_no_openmp
18+
from .context import skip_if_no_gfortran, skip_if_no_gcc, skip_if_no_openmp, skip_if_no_cupy
19+
from .test_runners import env as cuda_env # noqa: F401
1920

2021

2122
@skip_if_no_gcc
@@ -108,6 +109,29 @@ def test_ready_argument_list5():
108109
assert all(output[0].numpy == arg1)
109110

110111

112+
@skip_if_no_cupy
113+
def test_ready_argument_list6():
114+
import cupy as cp
115+
116+
arg = cp.array([1, 2, 3], dtype=np.float32)
117+
arguments = [arg]
118+
119+
cfunc = CompilerFunctions()
120+
output = cfunc.ready_argument_list(arguments)
121+
print(output)
122+
123+
assert len(output) == 1
124+
assert output[0].numpy is arg
125+
mem = cp.cuda.UnownedMemory(
126+
ptr=output[0].ctypes.value,
127+
size=int(arg.nbytes / arg.dtype.itemsize),
128+
owner=None,
129+
)
130+
ptr = cp.cuda.MemoryPointer(mem, 0)
131+
output_arg = cp.ndarray(shape=arg.shape, dtype=arg.dtype, memptr=ptr)
132+
assert cp.all(output_arg == arg)
133+
134+
111135
@skip_if_no_gcc
112136
def test_byte_array_arguments():
113137
arg1 = np.array([1, 2, 3]).astype(np.int8)
@@ -206,8 +230,29 @@ def test_memset():
206230
assert all(x == np.zeros(4))
207231

208232

209-
@skip_if_no_gcc
233+
@skip_if_no_cupy
210234
def test_memcpy_dtoh():
235+
import cupy as cp
236+
237+
a = [1, 2, 3, 4]
238+
x = cp.asarray(a, dtype=np.float32)
239+
x_c = C.c_void_p(x.data.ptr)
240+
arg = Argument(numpy=x, ctypes=x_c)
241+
output = np.zeros(len(x), dtype=x.dtype)
242+
243+
cfunc = CompilerFunctions()
244+
cfunc.memcpy_dtoh(output, arg)
245+
246+
print(f"{type(x)=} {x=}")
247+
print(f"{type(a)=} {a=}")
248+
print(f"{type(output)=} {output=}")
249+
250+
assert all(output == a)
251+
assert all(x.get() == a)
252+
253+
254+
@skip_if_no_gcc
255+
def test_memcpy_host_dtoh():
211256
a = [1, 2, 3, 4]
212257
x = np.array(a).astype(np.float32)
213258
x_c = x.ctypes.data_as(C.POINTER(C.c_float))
@@ -224,8 +269,44 @@ def test_memcpy_dtoh():
224269
assert all(x == a)
225270

226271

227-
@skip_if_no_gcc
272+
@skip_if_no_cupy
273+
def test_memcpy_device_dtoh():
274+
import cupy as cp
275+
276+
a = [1, 2, 3, 4]
277+
x = cp.asarray(a, dtype=np.float32)
278+
x_c = C.c_void_p(x.data.ptr)
279+
arg = Argument(numpy=x, ctypes=x_c)
280+
output = cp.zeros_like(x)
281+
282+
cfunc = CompilerFunctions()
283+
cfunc.memcpy_dtoh(output, arg)
284+
285+
print(f"{type(x)=} {x=}")
286+
print(f"{type(a)=} {a=}")
287+
print(f"{type(output)=} {output=}")
288+
289+
assert all(output.get() == a)
290+
assert all(x.get() == a)
291+
292+
293+
@skip_if_no_cupy
228294
def test_memcpy_htod():
295+
import cupy as cp
296+
297+
a = [1, 2, 3, 4]
298+
src = np.array(a, dtype=np.float32)
299+
x = cp.zeros(len(src), dtype=src.dtype)
300+
x_c = C.c_void_p(x.data.ptr)
301+
arg = Argument(numpy=x, ctypes=x_c)
302+
303+
cfunc = CompilerFunctions()
304+
cfunc.memcpy_htod(arg, src)
305+
306+
assert all(arg.numpy.get() == a)
307+
308+
309+
def test_memcpy_host_htod():
229310
a = [1, 2, 3, 4]
230311
src = np.array(a).astype(np.float32)
231312
x = np.zeros_like(src)
@@ -238,6 +319,22 @@ def test_memcpy_htod():
238319
assert all(arg.numpy == a)
239320

240321

322+
@skip_if_no_cupy
323+
def test_memcpy_device_htod():
324+
import cupy as cp
325+
326+
a = [1, 2, 3, 4]
327+
src = cp.array(a, dtype=np.float32)
328+
x = cp.zeros(len(src), dtype=src.dtype)
329+
x_c = C.c_void_p(x.data.ptr)
330+
arg = Argument(numpy=x, ctypes=x_c)
331+
332+
cfunc = CompilerFunctions()
333+
cfunc.memcpy_htod(arg, src)
334+
335+
assert all(arg.numpy.get() == a)
336+
337+
241338
@skip_if_no_gfortran
242339
def test_complies_fortran_function_no_module():
243340
kernel_string = """
@@ -335,3 +432,58 @@ def test_benchmark(env):
335432
assert all(["nthreads" in result for result in results])
336433
assert all(["time" in result for result in results])
337434
assert all([result["time"] > 0.0 for result in results])
435+
436+
437+
@skip_if_no_cupy
438+
def test_is_cupy_array():
439+
import cupy as cp
440+
441+
assert is_cupy_array(cp.array([1.0]))
442+
assert not is_cupy_array(np.array([1.0]))
443+
444+
445+
def test_is_cupy_array_no_cupy():
446+
assert not is_cupy_array(np.array([1.0]))
447+
448+
449+
@skip_if_no_cupy
450+
def test_get_array_module():
451+
import cupy as cp
452+
453+
assert get_array_module(cp.array([1.0])) == cp
454+
assert get_array_module(np.array([1.0])) == np
455+
456+
457+
@skip_if_no_cupy
458+
@skip_if_no_gcc
459+
def test_run_kernel():
460+
import cupy as cp
461+
462+
kernel_string = """
463+
__global__ void vector_add_kernel(float *c, const float *a, const float *b, int n) {
464+
int i = blockIdx.x * block_size_x + threadIdx.x;
465+
if (i<n) {
466+
c[i] = a[i] + b[i];
467+
}
468+
}
469+
470+
extern "C" void vector_add(float *c, const float *a, const float *b, int n) {
471+
dim3 dimGrid(n);
472+
dim3 dimBlock(block_size_x);
473+
vector_add_kernel<<<dimGrid, dimBlock>>>(c, a, b, n);
474+
}
475+
"""
476+
a = cp.asarray([1, 2.0], dtype=np.float32)
477+
b = cp.asarray([3, 4.0], dtype=np.float32)
478+
c = cp.zeros_like(b)
479+
n = np.int32(len(c))
480+
481+
result = kernel_tuner.run_kernel(
482+
kernel_name="vector_add",
483+
kernel_source=kernel_string,
484+
problem_size=n,
485+
arguments=[c, a, b, n],
486+
params={"block_size_x": 1},
487+
lang="C",
488+
)
489+
assert cp.all((a + b) == c)

0 commit comments

Comments
 (0)