Skip to content

Commit aa3cadb

Browse files
committed
Added test for the compiler memory refresh.
1 parent 2333916 commit aa3cadb

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

kernel_tuner/backends/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def memcpy_htod(self, dest, src):
364364
pass
365365

366366
def refresh_memory(self, arguments, should_sync):
367-
"""Copy the preserved content of the output memory to device pointers."""
367+
"""Copy the preserved content of the output memory to used arrays."""
368368
for i, arg in enumerate(arguments):
369369
if should_sync[i]:
370370
if isinstance(arg, np.ndarray) and is_cupy_array(self.allocations[i].numpy):

test/test_compiler_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,14 @@ def test_run_kernel():
368368
lang="C",
369369
)
370370
assert cp.all((a + b) == c)
371+
372+
373+
def test_refresh_memory():
374+
arg1 = np.array([1, 2, 3]).astype(np.int8)
375+
cfunc = CompilerFunctions()
376+
output = cfunc.ready_argument_list([arg1])
377+
assert np.all(output == arg1)
378+
arg1 = np.array([0, 0, 0]).astype(np.int8)
379+
assert np.all(arg1 == [0, 0, 0])
380+
cfunc.refresh_memory(arg1, [True])
381+
assert np.all(arg1 == [1, 2, 3])

0 commit comments

Comments
 (0)