Skip to content

Commit 4c77414

Browse files
committed
Although semantically there is no dtoh copy in the compiler backend, a copy of some kind is still needed. Plus a test.
1 parent 4587539 commit 4c77414

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

kernel_tuner/backends/compiler.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -356,24 +356,33 @@ def memset(self, allocation, value, size):
356356
C.memset(allocation.ctypes, value, size)
357357

358358
def memcpy_dtoh(self, dest, src):
359-
"""There is no memcpy_dtoh for the compiler backend."""
360-
pass
359+
"""This method implements the semantic of a device to host copy for the Compiler backend.
360+
There is no actual copy from device to host happening, but host to host.
361+
362+
:param dest: A numpy or cupy array to store the data
363+
:type dest: np.ndarray or cupy.ndarray
364+
365+
:param src: An Argument for some memory allocation
366+
:type src: Argument
367+
"""
368+
# there is no real copy from device to host, but host to host
369+
if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
370+
# Implicit conversion to a NumPy array is not allowed.
371+
value = src.numpy.get()
372+
else:
373+
value = src.numpy
374+
xp = get_array_module(dest)
375+
dest[:] = xp.asarray(value)
361376

362377
def memcpy_htod(self, dest, src):
363-
"""There is no memcpy_htod for the compiler backend."""
378+
"""There is no memcpy_htod implemented for the compiler backend."""
364379
pass
365380

366381
def refresh_memory(self, arguments, should_sync):
367382
"""Copy the preserved content of the output memory to used arrays."""
368383
for i, arg in enumerate(arguments):
369384
if should_sync[i]:
370-
if isinstance(arg, np.ndarray) and is_cupy_array(self.allocations[i].numpy):
371-
# Implicit conversion to a NumPy array is not allowed.
372-
value = self.allocations[i].numpy.get()
373-
else:
374-
value = self.allocations[i].numpy
375-
xp = get_array_module(arg)
376-
arg[:] = xp.asarray(value)
385+
self.memcpy_dtoh(arg, self.allocations[i])
377386

378387
def cleanup_lib(self):
379388
"""Unload the previously loaded shared library"""

test/test_compiler_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,14 @@ def test_refresh_memory():
380380
assert np.all(arguments[0] == [0, 0, 0])
381381
cfunc.refresh_memory(arguments, [True])
382382
assert np.all(arguments[0] == [1, 2, 3])
383+
384+
385+
def test_memcpy_dtoh():
386+
arg1 = np.array([0, 5, 0, 7]).astype(np.int32)
387+
arguments = [arg1]
388+
cfunc = CompilerFunctions()
389+
ready_arguments = cfunc.ready_argument_list(arguments)
390+
expected = np.array([0, 0, 0, 0]).astype(np.float32)
391+
assert np.all(ready_arguments.numpy != expected)
392+
cfunc.memcpy_dtoh(expected, ready_arguments)
393+
assert np.all(ready_arguments.numpy == expected)

0 commit comments

Comments
 (0)