Skip to content

Commit 943b3c4

Browse files
committed
Added tests for RegisterObserver; added a guard clause in the PyCUDA backend's compile() for the mock-test case
1 parent 81a68a4 commit 943b3c4

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

kernel_tuner/backends/pycuda.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,8 @@ def compile(self, kernel_instance):
218218
)
219219

220220
self.func = self.current_module.get_function(kernel_name)
221-
self.num_regs = self.func.num_regs
221+
if not isinstance(self.func, str):
222+
self.num_regs = self.func.num_regs
222223
return self.func
223224
except drv.CompileError as e:
224225
if "uses too much shared data" in e.stderr:

test/test_observers.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,10 @@
22

33
import kernel_tuner
44
from kernel_tuner.observers.nvml import NVMLObserver
5+
from kernel_tuner.observers.register import RegisterObserver
56
from kernel_tuner.observers.observer import BenchmarkObserver
67

7-
from .context import skip_if_no_pycuda, skip_if_no_pynvml
8+
from .context import skip_if_no_pycuda, skip_if_no_pynvml, skip_if_no_cupy, skip_if_no_cuda
89
from .test_runners import env # noqa: F401
910

1011

@@ -20,6 +21,29 @@ def test_nvml_observer(env):
2021
assert "temperature" in result[0]
2122
assert result[0]["temperature"] > 0
2223

24+
@skip_if_no_pycuda
25+
def test_register_observer_pycuda(env):
26+
registerobserver = RegisterObserver()
27+
env[-1]["block_size_x"] = [128]
28+
result, _ = kernel_tuner.tune_kernel(*env, observers=[registerobserver], lang='CUDA')
29+
assert "num_regs" in result[0]
30+
assert result[0]["num_regs"] > 0
31+
32+
@skip_if_no_cupy
33+
def test_register_observer_cupy(env):
34+
registerobserver = RegisterObserver()
35+
env[-1]["block_size_x"] = [128]
36+
result, _ = kernel_tuner.tune_kernel(*env, observers=[registerobserver], lang='CuPy')
37+
assert "num_regs" in result[0]
38+
assert result[0]["num_regs"] > 0
39+
40+
@skip_if_no_cuda
41+
def test_register_observer_nvcuda(env):
42+
registerobserver = RegisterObserver()
43+
env[-1]["block_size_x"] = [128]
44+
result, _ = kernel_tuner.tune_kernel(*env, observers=[registerobserver], lang='NVCUDA')
45+
assert "num_regs" in result[0]
46+
assert result[0]["num_regs"] > 0
2347

2448
@skip_if_no_pycuda
2549
def test_custom_observer(env):

0 commit comments

Comments
 (0)