Skip to content

Commit 165ac8d

Browse files
author
Diptorup Deb
committed
Add experimental kernel decorator to numba jit_registry
1 parent 5d04c14 commit 165ac8d

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

numba_dpex/experimental/decorators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import inspect
99

1010
from numba.core import sigutils
11+
from numba.core.target_extension import jit_registry, target_registry
1112

1213
from .kernel_dispatcher import KernelDispatcher
1314

@@ -78,3 +79,6 @@ def _specialized_kernel_dispatcher(pyfunc):
7879
"the return type as void explicitly."
7980
)
8081
return _kernel_dispatcher(func)
82+
83+
84+
jit_registry[target_registry["dpex_kernel"]] = kernel

numba_dpex/tests/experimental/test_exec_queue_inference.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import dpctl
77
import dpnp
88
import pytest
9+
from numba.core import config
910

1011
import numba_dpex.experimental as exp_dpex
1112
from numba_dpex import Range
@@ -34,14 +35,16 @@ def test_successful_execution_queue_inference():
3435
c = dpnp.zeros_like(a, sycl_queue=q)
3536
r = Range(100)
3637

37-
# FIXME: This test fails unexpectedly if the NUMBA_CAPTURED_ERRORS is set
38-
# to "new_style".
39-
# Refer: https://github.com/IntelPython/numba-dpex/issues/1195
38+
current_captured_error_style = config.CAPTURED_ERRORS
39+
config.CAPTURED_ERRORS = "new_style"
40+
4041
try:
4142
exp_dpex.call_kernel(add, r, a, b, c)
4243
except:
4344
pytest.fail("Unexpected error when calling kernel")
4445

46+
config.CAPTURED_ERRORS = current_captured_error_style
47+
4548
assert c[0] == b[0] + a[0]
4649

4750

@@ -59,8 +62,6 @@ def test_execution_queue_inference_error():
5962
c = dpnp.zeros_like(a, sycl_queue=q1)
6063
r = Range(100)
6164

62-
from numba.core import config
63-
6465
current_captured_error_style = config.CAPTURED_ERRORS
6566
config.CAPTURED_ERRORS = "new_style"
6667

0 commit comments

Comments
 (0)