
Commit e76b774

Merge pull request #226 from bouweandela/add-compiler-cupy-array-support
Add support for passing cupy arrays to "C" lang
2 parents 66428e3 + 303ef3a commit e76b774

File tree

4 files changed (+311 -17 lines)


doc/source/hostcode.rst

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,8 @@ There are few differences with tuning just a single CUDA or OpenCL kernel, to li
 * You have to specify the lang="C" option
 * The C function should return a ``float``
 * You have to do your own timing and error handling in C
+* Data is not automatically copied to and from device memory. To use an array in host memory, pass in a :mod:`numpy` array. To use an array
+  in device memory, pass in a :mod:`cupy` array.
 
 You have to specify the language as "C" because the Kernel Tuner will be calling a host function. This means that the Kernel
 Tuner will have to interface with C and in fact uses a different backend. This also means you can use this way of tuning
@@ -94,7 +96,7 @@ compiled C code. This way, you don't have to compute the grid size in C, you can
 
 The filter is not passed separately as a constant memory argument, because the CudaMemcpyToSymbol operation is now performed by the C host function. Also,
 because the code is compiled differently, we have no direct reference to the compiled module that is uploaded to the device and therefore we can not perform this
-operation directly from Python. If you are tuning host code, you have to perform all memory allocations, frees, and memcpy operations inside the C host code,
+operation directly from Python. If you are tuning host code, you have the option to perform all memory allocations, frees, and memcpy operations inside the C host code,
 that's the purpose of host code after all. That is also why you have to do the timing yourself in C, as you may not want to include the time spent on memory
 allocations and other setup into your time measurements.
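To make the new rule concrete, here is a minimal sketch of calling tune_kernel with mixed host and device arguments; the host function name vector_add_host, its source file, and the argument order are hypothetical and not part of this commit:

import cupy as cp
import numpy as np
import kernel_tuner

size = np.int32(1000000)
a = np.random.randn(size).astype(np.float32)  # numpy array: passed as host memory
b = cp.asarray(a)                             # cupy array: passed as device memory
c = cp.zeros_like(b)                          # cupy array for the result

tune_params = {"block_size_x": [128, 256, 512]}

# With lang="C" the (hypothetical) host function receives a host pointer for
# the numpy array and raw device pointers for the cupy arrays; it must do its
# own copies, timing, and error handling, and return a float.
results = kernel_tuner.tune_kernel("vector_add_host", "vector_add_host.cu",
                                   size, [c, b, a, size], tune_params, lang="C")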

examples/cuda/pnpoly_cupy.py

Lines changed: 90 additions & 0 deletions
#!/usr/bin/env python
""" Point-in-Polygon host/device code tuner

This program is used for auto-tuning the host and device code of a CUDA program
for computing the point-in-polygon problem for very large datasets and large
polygons.

The time measurements used as a basis for tuning include the time spent on
data transfers between host and device memory. The host code uses device mapped
host memory to overlap communication between host and device with kernel
execution on the GPU. Because each input is read only once and each output
is written only once, this implementation almost fully overlaps all
communication and the kernel execution time dominates the total execution time.

The code has the option to precompute all polygon line slopes on the CPU and
reuse those results on the GPU, instead of recomputing them on the GPU all
the time. The time spent on precomputing these values on the CPU is also
taken into account by the time measurement in the code.

This code was written for use with the Kernel Tuner. See:
    https://github.com/benvanwerkhoven/kernel_tuner

Author: Ben van Werkhoven <[email protected]>
"""
from collections import OrderedDict
import json
import logging

import cupy as cp
import cupyx as cpx
import kernel_tuner
import numpy


def allocator(size: int) -> cp.cuda.PinnedMemoryPointer:
    """Allocate context-portable device mapped host memory."""
    flags = cp.cuda.runtime.hostAllocPortable | cp.cuda.runtime.hostAllocMapped
    mem = cp.cuda.PinnedMemory(size, flags=flags)
    return cp.cuda.PinnedMemoryPointer(mem, offset=0)


def tune():

    #set the number of points and the number of vertices
    size = numpy.int32(2e7)
    problem_size = (size, 1)
    vertices = 600

    #allocate context-portable device mapped host memory
    cp.cuda.set_pinned_memory_allocator(allocator)

    #generate input data
    points = cpx.empty_pinned(shape=(2*size,), dtype=numpy.float32)
    points[:] = numpy.random.randn(2*size).astype(numpy.float32)

    bitmap = cpx.zeros_pinned(shape=(size,), dtype=numpy.int32)
    #as test input we use a circle with radius 1 as polygon and
    #a large set of normally distributed points around 0,0
    vertex_seeds = numpy.sort(numpy.random.rand(vertices)*2.0*numpy.pi)[::-1]
    vertex_x = numpy.cos(vertex_seeds)
    vertex_y = numpy.sin(vertex_seeds)
    vertex_xy = cpx.empty_pinned(shape=(2*vertices,), dtype=numpy.float32)
    vertex_xy[:] = numpy.array( list(zip(vertex_x, vertex_y)) ).astype(numpy.float32).ravel()

    #kernel arguments
    args = [bitmap, points, vertex_xy, size]

    #setup tunable parameters
    tune_params = OrderedDict()
    tune_params["block_size_x"] = [32*i for i in range(1,32)] #multiple of 32
    tune_params["tile_size"] = [1] + [2*i for i in range(1,11)]
    tune_params["between_method"] = [0, 1, 2, 3]
    tune_params["use_precomputed_slopes"] = [0, 1]
    tune_params["use_method"] = [0, 1]

    #tell the Kernel Tuner how to compute the grid dimensions from the problem_size
    grid_div_x = ["block_size_x", "tile_size"]

    #start tuning
    results = kernel_tuner.tune_kernel("cn_pnpoly_host", ['pnpoly_host.cu', 'pnpoly.cu'],
        problem_size, args, tune_params,
        grid_div_x=grid_div_x, lang="C", compiler_options=["-arch=sm_52"], verbose=True, log=logging.DEBUG)

    return results


if __name__ == "__main__":
    results = tune()
    with open("pnpoly.json", 'w') as fp:
        json.dump(results, fp)
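Note that the cupyx pinned arrays used above are numpy arrays backed by page-locked host memory, so the C host function receives host pointers. A hypothetical variant of the argument setup, using the device-memory path this commit adds, could look like the sketch below; the rest of the script would stay the same, but the host code would then receive raw device pointers and should skip its own host-device copies:

    #hypothetical device-memory variant of the argument setup above
    points_d = cp.asarray(points)                      #device copy of the input points
    bitmap_d = cp.zeros((int(size),), dtype=cp.int32)  #device output bitmap
    vertex_xy_d = cp.asarray(vertex_xy)                #device copy of the vertex data
    args = [bitmap_d, points_d, vertex_xy_d, size]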

kernel_tuner/backends/compiler.py

Lines changed: 62 additions & 12 deletions
@@ -21,6 +21,39 @@
     SkippableFailure,
 )
 
+try:
+    import cupy as cp
+except ImportError:
+    cp = None
+
+
+def is_cupy_array(array):
+    """Check if something is a cupy array.
+
+    :param array: A Python object.
+    :type array: typing.Any
+
+    :returns: True if cupy can be imported and the object is a cupy.ndarray.
+    :rtype: bool
+    """
+    return cp is not None and isinstance(array, cp.ndarray)
+
+
+def get_array_module(*args):
+    """Return the array module for arguments.
+
+    This function is used to implement CPU/GPU generic code. If the cupy module can be imported
+    and at least one of the arguments is a cupy.ndarray object, the cupy module is returned.
+
+    :param args: Values to determine whether NumPy or CuPy should be used.
+    :type args: numpy.ndarray or cupy.ndarray
+
+    :returns: cupy or numpy is returned based on the types of the arguments.
+    :rtype: types.ModuleType
+    """
+    return np if cp is None else cp.get_array_module(*args)
+
+
 dtype_map = {
     "int8": C.c_int8,
     "int16": C.c_int16,
@@ -103,18 +136,18 @@ def ready_argument_list(self, arguments):
 
         :param arguments: List of arguments to be passed to the C function.
             The order should match the argument list on the C function.
-            Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
-        :type arguments: list(numpy objects)
+            Allowed values are np.ndarray, cupy.ndarray, and/or np.int32, np.float32, and so on.
+        :type arguments: list(numpy or cupy objects)
 
         :returns: A list of arguments that can be passed to the C function.
         :rtype: list(Argument)
         """
         ctype_args = [None for _ in arguments]
 
         for i, arg in enumerate(arguments):
-            if not isinstance(arg, (np.ndarray, np.number)):
+            if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)):
                 raise TypeError(
-                    "Argument is not numpy ndarray or numpy scalar %s" % type(arg)
+                    f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}"
                 )
             dtype_str = str(arg.dtype)
             if isinstance(arg, np.ndarray):
@@ -129,6 +162,8 @@ def ready_argument_list(self, arguments):
                 raise TypeError("unknown dtype for ndarray")
             elif isinstance(arg, np.generic):
                 data_ctypes = dtype_map[dtype_str](arg)
+            elif is_cupy_array(arg):
+                data_ctypes = C.c_void_p(arg.data.ptr)
             ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes)
         return ctype_args
 
@@ -326,29 +361,44 @@ def memset(self, allocation, value, size):
         :param size: The size of to the allocation unit in bytes
         :type size: int
         """
-        C.memset(allocation.ctypes, value, size)
+        if is_cupy_array(allocation.numpy):
+            cp.cuda.runtime.memset(allocation.numpy.data.ptr, value, size)
+        else:
+            C.memset(allocation.ctypes, value, size)
 
     def memcpy_dtoh(self, dest, src):
         """a simple memcpy copying from an Argument to a numpy array
 
-        :param dest: A numpy array to store the data
-        :type dest: np.ndarray
+        :param dest: A numpy or cupy array to store the data
+        :type dest: np.ndarray or cupy.ndarray
 
         :param src: An Argument for some memory allocation
        :type src: Argument
         """
-        dest[:] = src.numpy
+        if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
+            # Implicit conversion to a NumPy array is not allowed.
+            value = src.numpy.get()
+        else:
+            value = src.numpy
+        xp = get_array_module(dest)
+        dest[:] = xp.asarray(value)
 
     def memcpy_htod(self, dest, src):
         """a simple memcpy copying from a numpy array to an Argument
 
         :param dest: An Argument for some memory allocation
-        :type dst: Argument
+        :type dest: Argument
 
-        :param src: A numpy array containing the source data
-        :type src: np.ndarray
+        :param src: A numpy or cupy array containing the source data
+        :type src: np.ndarray or cupy.ndarray
         """
-        dest.numpy[:] = src
+        if isinstance(dest.numpy, np.ndarray) and is_cupy_array(src):
+            # Implicit conversion to a NumPy array is not allowed.
+            value = src.get()
+        else:
+            value = src
+        xp = get_array_module(dest.numpy)
+        dest.numpy[:] = xp.asarray(value)
 
     def cleanup_lib(self):
         """unload the previously loaded shared library"""
