Skip to content

Commit 145f980

Browse files
author
Diptorup Deb
authored
Merge pull request #1064 from IntelPython/change/DpctlSyclQueue_type
Changes to the DpctlSyclQueue and USMNdArray types.
2 parents a1ee24e + bbd6132 commit 145f980

27 files changed

+431
-407
lines changed

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,14 @@ def __call__(self, *args):
409409
# FIXME: For specialized and ahead of time compiled and cached kernels,
410410
# the CFD check was already done statically. The run-time check is
411411
# redundant. We should avoid these checks for the specialized case.
412-
exec_queue = determine_kernel_launch_queue(
412+
ty_queue = determine_kernel_launch_queue(
413413
args, argtypes, self.kernel_name
414414
)
415+
416+
# FIXME: We need a better way than having to create a queue every time.
417+
device = ty_queue.sycl_device
418+
exec_queue = dpctl.get_device_cached_queue(device)
419+
415420
backend = exec_queue.backend
416421

417422
if exec_queue.backend not in [

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import warnings
88

9+
import dpctl
910
import dpctl.program as dpctl_prog
1011
from numba.core import ir, types
1112
from numba.core.errors import NumbaParallelSafetyWarning
@@ -426,7 +427,10 @@ def create_kernel_for_parfor(
426427
for arg in parfor_args:
427428
obj = typemap[arg]
428429
if isinstance(obj, DpnpNdArray):
429-
exec_queue = obj.queue
430+
filter_string = obj.queue.sycl_device
431+
# FIXME: A better design is required so that we do not have to
432+
# create a queue every time.
433+
exec_queue = dpctl.get_device_cached_queue(filter_string)
430434

431435
if not exec_queue:
432436
raise AssertionError(

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import warnings
66

7+
import dpctl
78
from numba.core import types
89
from numba.core.errors import NumbaParallelSafetyWarning
910
from numba.core.ir_utils import (
@@ -18,6 +19,8 @@
1819
)
1920
from numba.core.typing import signature
2021

22+
from numba_dpex.core.types import DpctlSyclQueue
23+
2124
from ..utils.kernel_templates.reduction_template import (
2225
RemainderReduceIntermediateKernelTemplate,
2326
TreeReduceIntermediateKernelTemplate,
@@ -134,7 +137,13 @@ def create_reduction_main_kernel_for_parfor(
134137
flags.noalias = True
135138

136139
kernel_sig = signature(types.none, *kernel_param_types)
137-
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue
140+
141+
# FIXME: A better design is required so that we do not have to create a
142+
# queue every time.
143+
ty_queue: DpctlSyclQueue = typemap[
144+
reductionKernelVar.parfor_params[0]
145+
].queue
146+
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)
138147

139148
sycl_kernel = _compile_kernel_parfor(
140149
exec_queue,
@@ -331,11 +340,12 @@ def create_reduction_remainder_kernel_for_parfor(
331340

332341
kernel_sig = signature(types.none, *kernel_param_types)
333342

334-
# FIXME: Enable check after CFD pass has been added
335-
# exec_queue = determine_kernel_launch_queue(
336-
# args=parfor_args, argtypes=kernel_param_types, kernel_name=kernel_name
337-
# )
338-
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue
343+
# FIXME: A better design is required so that we do not have to create a
344+
# queue every time.
345+
ty_queue: DpctlSyclQueue = typemap[
346+
reductionKernelVar.parfor_params[0]
347+
].queue
348+
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)
339349

340350
sycl_kernel = _compile_kernel_parfor(
341351
exec_queue,

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
11801180
PyGILState_STATE gstate;
11811181

11821182
// Increment the ref count on obj to prevent CPython from garbage
1183-
// collecting the array.
1183+
// collecting the dpctl.SyclQueue object
11841184
Py_IncRef(obj);
11851185

11861186
// We are unconditionally casting obj to a struct PySyclQueueObject*. If

numba_dpex/core/typeconv/array_conversion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from numba.np import numpy_support
66

7-
from numba_dpex.core.types import USMNdArray
8-
from numba_dpex.core.utils import get_info_from_suai
7+
from numba_dpex.core.types import DpctlSyclQueue, USMNdArray
98
from numba_dpex.utils.constants import address_space
109

1110

@@ -37,7 +36,7 @@ def to_usm_ndarray(suai_attrs, addrspace=address_space.GLOBAL):
3736
ndim=suai_attrs.dimensions,
3837
layout=layout,
3938
usm_type=suai_attrs.usm_type,
40-
queue=suai_attrs.queue,
39+
queue=DpctlSyclQueue(suai_attrs.queue),
4140
readonly=not suai_attrs.is_writable,
4241
name=None,
4342
aligned=True,

numba_dpex/core/types/dpctl_types.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,18 @@
1414

1515

1616
class DpctlSyclQueue(types.Type):
17-
"""A Numba type to represent a dpctl.SyclQueue PyObject.
18-
19-
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
20-
passing in and using a SyclQueue object as an opaque pointer type inside
21-
Numba.
22-
"""
17+
"""A Numba type to represent a dpctl.SyclQueue PyObject."""
2318

2419
def __init__(self, sycl_queue):
2520
if not isinstance(sycl_queue, SyclQueue):
2621
raise TypeError("The argument sycl_queue is not of type SyclQueue.")
2722

28-
self._sycl_queue = sycl_queue
23+
# XXX: Storing the device filter string is a temporary workaround till
24+
# the compute follows data inference pass is fixed to use SyclQueue
25+
self._device = sycl_queue.sycl_device.filter_string
26+
2927
try:
30-
self._unique_id = hash(self._sycl_queue)
28+
self._unique_id = hash(sycl_queue)
3129
except Exception:
3230
self._unique_id = self.rand_digit_str(16)
3331
super(DpctlSyclQueue, self).__init__(name="DpctlSyclQueue")
@@ -38,8 +36,14 @@ def rand_digit_str(self, n):
3836
)
3937

4038
@property
41-
def sycl_queue(self):
42-
return self._sycl_queue
39+
def sycl_device(self):
40+
"""Returns the SYCL oneAPI extension filter string associated with the
41+
queue.
42+
43+
Returns:
44+
str: A SYCL oneAPI extension filter string
45+
"""
46+
return self._device
4347

4448
@property
4549
def key(self):
@@ -69,11 +73,8 @@ def unbox_sycl_queue(typ, obj, c):
6973
qptr = qstruct._getpointer()
7074
ptr = c.builder.bitcast(qptr, c.pyapi.voidptr)
7175

72-
if c.context.enable_nrt:
73-
dpexrtCtx = dpexrt.DpexRTContext(c.context)
74-
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
75-
else:
76-
raise UnreachableError
76+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
77+
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
7778
is_error = cgutils.is_not_null(c.builder, errcode)
7879

7980
# Handle error

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from numba.core.types.npytypes import Array
1313
from numba.np.numpy_support import from_dtype
1414

15+
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
1516
from numba_dpex.utils import address_space
1617

1718

@@ -31,22 +32,28 @@ def __init__(
3132
aligned=True,
3233
addrspace=address_space.GLOBAL,
3334
):
34-
if queue and not isinstance(queue, types.misc.Omitted) and device:
35+
if (
36+
queue is not None
37+
and not (
38+
isinstance(queue, types.misc.Omitted)
39+
or isinstance(queue, types.misc.NoneType)
40+
)
41+
and device is not None
42+
):
3543
raise TypeError(
3644
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
37-
"`device` and `sycl_queue` are exclusive keywords, i.e. use one or other."
45+
"`device` and `sycl_queue` are exclusive keywords, "
46+
"i.e. use one or other."
3847
)
3948

40-
self.usm_type = usm_type
41-
self.addrspace = addrspace
42-
43-
if queue and not isinstance(queue, types.misc.Omitted):
44-
if not isinstance(queue, dpctl.SyclQueue):
49+
if queue is not None and not (
50+
isinstance(queue, types.misc.Omitted)
51+
or isinstance(queue, types.misc.NoneType)
52+
):
53+
if not isinstance(queue, DpctlSyclQueue):
4554
raise TypeError(
46-
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
47-
"The queue keyword arg should be a dpctl.SyclQueue object or None."
48-
"Found type(queue) ="
49-
+ str(type(queue) + " and queue =" + queue)
55+
"The queue keyword arg should be either DpctlSyclQueue or "
56+
"NoneType. Found type(queue) = " + str(type(queue))
5057
)
5158
self.queue = queue
5259
else:
@@ -55,24 +62,23 @@ def __init__(
5562
else:
5663
if not isinstance(device, str):
5764
raise TypeError(
58-
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
59-
"The device keyword arg should be a str object specifying "
60-
"a SYCL filter selector."
65+
"The device keyword arg should be a str object "
66+
"specifying a SYCL filter selector."
6167
)
6268
sycl_device = dpctl.SyclDevice(device)
6369

64-
self.queue = dpctl._sycl_queue_manager.get_device_cached_queue(
70+
sycl_queue = dpctl._sycl_queue_manager.get_device_cached_queue(
6571
sycl_device
6672
)
73+
self.queue = DpctlSyclQueue(sycl_queue=sycl_queue)
6774

68-
self.device = self.queue.sycl_device.filter_string
75+
self.device = self.queue.sycl_device
76+
self.usm_type = usm_type
77+
self.addrspace = addrspace
6978

7079
if not dtype:
7180
dummy_tensor = dpctl.tensor.empty(
72-
1,
73-
order=layout,
74-
usm_type=usm_type,
75-
sycl_queue=self.queue,
81+
1, order=layout, usm_type=usm_type, device=self.device
7682
)
7783
# convert dpnp type to numba/numpy type
7884
_dtype = dummy_tensor.dtype

numba_dpex/core/typing/typeof.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ def _typeof_helper(val, array_class_type):
4242
"The usm_type for the usm_ndarray could not be inferred"
4343
)
4444

45-
assert val.sycl_queue is not None
45+
if not val.sycl_queue:
46+
raise AssertionError
47+
48+
ty_queue = DpctlSyclQueue(sycl_queue=val.sycl_queue)
4649

4750
return array_class_type(
4851
dtype=dtype,
4952
ndim=val.ndim,
5053
layout=layout,
5154
readonly=readonly,
5255
usm_type=usm_type,
53-
queue=val.sycl_queue,
56+
queue=ty_queue,
5457
addrspace=address_space.GLOBAL,
5558
)
5659

0 commit comments

Comments
 (0)