Skip to content

Commit 16443b4

Browse files
committed
GPU and Compiler backends have a different way to refresh memory.
1 parent 9f07354 commit 16443b4

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

kernel_tuner/backends/backend.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,10 @@ def memcpy_htod(self, dest, src):
5757
"""This method must implement a host to device copy."""
5858
pass
5959

60-
def refresh_memory(self, arguments, should_sync):
61-
"""Copy the original content of the output memory to device memory."""
62-
for i, arg in enumerate(arguments):
63-
if should_sync[i]:
64-
self.memcpy_htod(self.allocations[i], arg)
60+
@abstractmethod
61+
def refresh_memory(self, device_memory, host_arguments, should_sync):
62+
"""This method must implement refreshing the device memory with a clean copy."""
63+
pass
6564

6665

6766
class GPUBackend(Backend):
@@ -86,6 +85,12 @@ def copy_texture_memory_args(self, texmem_args):
8685
"""This method must implement the allocation and copy of texture memory to the GPU."""
8786
pass
8887

88+
def refresh_memory(self, gpu_memory, host_arguments, should_sync):
89+
"""Refresh the GPU memory with the untouched host arguments."""
90+
for i, arg in enumerate(host_arguments):
91+
if should_sync[i]:
92+
self.memcpy_htod(gpu_memory[i], arg)
93+
8994

9095
class CompilerBackend(Backend):
9196
"""Base class for compiler backends"""
@@ -94,6 +99,12 @@ class CompilerBackend(Backend):
9499
def __init__(self, iterations, compiler_options, compiler):
95100
pass
96101

102+
def refresh_memory(self, gpu_memory, host_arguments, should_sync):
103+
"""Refresh the GPU memory with the untouched host arguments."""
104+
for i, arg in enumerate(host_arguments):
105+
if should_sync[i]:
106+
self.memcpy_htod(self.allocations[i], arg)
107+
97108
@abstractmethod
98109
def cleanup_lib(self):
99110
"""Unload the previously loaded shared library"""

kernel_tuner/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def check_kernel_output(self, func, gpu_args, instance, answer, atol, verify, ve
489489

490490
# re-copy original contents of output arguments to GPU memory, to overwrite any changes
491491
# by earlier kernel runs
492-
self.dev.refresh_memory(instance.arguments, should_sync)
492+
self.dev.refresh_memory(gpu_args, instance.arguments, should_sync)
493493

494494
# run the kernel
495495
check = self.run_kernel(func, gpu_args, instance)

0 commit comments

Comments
 (0)