Skip to content

Commit 81a68a4

Browse files
committed
Added RegisterObserver with common interface among backends
1 parent 4dbcb66 commit 81a68a4

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed

kernel_tuner/backends/cupy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def compile(self, kernel_instance):
132132
)
133133

134134
self.func = self.current_module.get_function(kernel_name)
135+
self.num_regs = self.func.num_regs
135136
return self.func
136137

137138
def start_event(self):

kernel_tuner/backends/nvcuda.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ def compile(self, kernel_instance):
192192
)
193193
cuda_error_check(err)
194194

195+
# get the number of registers per thread used in this kernel
196+
num_regs = cuda.cuFuncGetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func)
197+
assert num_regs[0] == 0, f"Retrieving number of registers per thread unsuccesful: code {num_regs[0]}"
198+
self.num_regs = num_regs[1]
199+
195200
except RuntimeError as re:
196201
_, n = nvrtc.nvrtcGetProgramLogSize(program)
197202
log = b" " * n

kernel_tuner/backends/pycuda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def compile(self, kernel_instance):
218218
)
219219

220220
self.func = self.current_module.get_function(kernel_name)
221+
self.num_regs = self.func.num_regs
221222
return self.func
222223
except drv.CompileError as e:
223224
if "uses too much shared data" in e.stderr:

kernel_tuner/observers/register.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from kernel_tuner.observers.observer import BenchmarkObserver
2+
3+
class RegisterObserver(BenchmarkObserver):
4+
"""Observer for counting the number of registers."""
5+
6+
def __init__(self) -> None:
7+
super().__init__()
8+
9+
def get_results(self):
10+
return {
11+
"num_regs": self.dev.num_regs
12+
}

0 commit comments

Comments
 (0)