Skip to content

Commit bbd6132

Browse files
author
Diptorup Deb
committed
Changed the DpctlSyclQueue and USMNdArray types.
- Storing the Python dpctl.SyclQueue inside any instance of the DpctlSyclQueue type was causing segfaults due to the Python object getting garbage collected prematurely. The changes in this PR update the DpctlSyclQueue type to store only the filter string associated with the dpctl.SyclQueue and not the actual Python object. In addition, the USMNdArray type now stores an instance of DpctlSyclQueue in its queue parameter instead of a Python dpctl.SyclQueue object. Due to these changes, all places where the Python dpctl.SyclQueue was getting extracted and used from a USMNdArray instance or DpctlSyclQueue instance have been updated. All test cases were also updated.
1 parent 4b9aacf commit bbd6132

File tree

16 files changed

+368
-303
lines changed

16 files changed

+368
-303
lines changed

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,14 @@ def __call__(self, *args):
409409
# FIXME: For specialized and ahead of time compiled and cached kernels,
410410
# the CFD check was already done statically. The run-time check is
411411
# redundant. We should avoid these checks for the specialized case.
412-
exec_queue = determine_kernel_launch_queue(
412+
ty_queue = determine_kernel_launch_queue(
413413
args, argtypes, self.kernel_name
414414
)
415+
416+
# FIXME: We need a better way than having to create a queue every time.
417+
device = ty_queue.sycl_device
418+
exec_queue = dpctl.get_device_cached_queue(device)
419+
415420
backend = exec_queue.backend
416421

417422
if exec_queue.backend not in [

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import warnings
88

9+
import dpctl
910
import dpctl.program as dpctl_prog
1011
from numba.core import ir, types
1112
from numba.core.errors import NumbaParallelSafetyWarning
@@ -426,7 +427,10 @@ def create_kernel_for_parfor(
426427
for arg in parfor_args:
427428
obj = typemap[arg]
428429
if isinstance(obj, DpnpNdArray):
429-
exec_queue = obj.queue
430+
filter_string = obj.queue.sycl_device
431+
# FIXME: A better design is required so that we do not have to
432+
# create a queue every time.
433+
exec_queue = dpctl.get_device_cached_queue(filter_string)
430434

431435
if not exec_queue:
432436
raise AssertionError(

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import warnings
66

7+
import dpctl
78
from numba.core import types
89
from numba.core.errors import NumbaParallelSafetyWarning
910
from numba.core.ir_utils import (
@@ -18,6 +19,8 @@
1819
)
1920
from numba.core.typing import signature
2021

22+
from numba_dpex.core.types import DpctlSyclQueue
23+
2124
from ..utils.kernel_templates.reduction_template import (
2225
RemainderReduceIntermediateKernelTemplate,
2326
TreeReduceIntermediateKernelTemplate,
@@ -134,7 +137,13 @@ def create_reduction_main_kernel_for_parfor(
134137
flags.noalias = True
135138

136139
kernel_sig = signature(types.none, *kernel_param_types)
137-
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue
140+
141+
# FIXME: A better design is required so that we do not have to create a
142+
# queue every time.
143+
ty_queue: DpctlSyclQueue = typemap[
144+
reductionKernelVar.parfor_params[0]
145+
].queue
146+
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)
138147

139148
sycl_kernel = _compile_kernel_parfor(
140149
exec_queue,
@@ -331,11 +340,12 @@ def create_reduction_remainder_kernel_for_parfor(
331340

332341
kernel_sig = signature(types.none, *kernel_param_types)
333342

334-
# FIXME: Enable check after CFD pass has been added
335-
# exec_queue = determine_kernel_launch_queue(
336-
# args=parfor_args, argtypes=kernel_param_types, kernel_name=kernel_name
337-
# )
338-
exec_queue = typemap[reductionKernelVar.parfor_params[0]].queue
343+
# FIXME: A better design is required so that we do not have to create a
344+
# queue every time.
345+
ty_queue: DpctlSyclQueue = typemap[
346+
reductionKernelVar.parfor_params[0]
347+
].queue
348+
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)
339349

340350
sycl_kernel = _compile_kernel_parfor(
341351
exec_queue,

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
11801180
PyGILState_STATE gstate;
11811181

11821182
// Increment the ref count on obj to prevent CPython from garbage
1183-
// collecting the array.
1183+
// collecting the dpctl.SyclQueue object
11841184
Py_IncRef(obj);
11851185

11861186
// We are unconditionally casting obj to a struct PySyclQueueObject*. If

numba_dpex/core/typeconv/array_conversion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from numba.np import numpy_support
66

7-
from numba_dpex.core.types import USMNdArray
8-
from numba_dpex.core.utils import get_info_from_suai
7+
from numba_dpex.core.types import DpctlSyclQueue, USMNdArray
98
from numba_dpex.utils.constants import address_space
109

1110

@@ -37,7 +36,7 @@ def to_usm_ndarray(suai_attrs, addrspace=address_space.GLOBAL):
3736
ndim=suai_attrs.dimensions,
3837
layout=layout,
3938
usm_type=suai_attrs.usm_type,
40-
queue=suai_attrs.queue,
39+
queue=DpctlSyclQueue(suai_attrs.queue),
4140
readonly=not suai_attrs.is_writable,
4241
name=None,
4342
aligned=True,

numba_dpex/core/types/dpctl_types.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,18 @@
1414

1515

1616
class DpctlSyclQueue(types.Type):
17-
"""A Numba type to represent a dpctl.SyclQueue PyObject.
18-
19-
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
20-
passing in and using a SyclQueue object as an opaque pointer type inside
21-
Numba.
22-
"""
17+
"""A Numba type to represent a dpctl.SyclQueue PyObject."""
2318

2419
def __init__(self, sycl_queue):
2520
if not isinstance(sycl_queue, SyclQueue):
2621
raise TypeError("The argument sycl_queue is not of type SyclQueue.")
2722

28-
self._sycl_queue = sycl_queue
23+
# XXX: Storing the device filter string is a temporary workaround till
24+
# the compute follows data inference pass is fixed to use SyclQueue
25+
self._device = sycl_queue.sycl_device.filter_string
26+
2927
try:
30-
self._unique_id = hash(self._sycl_queue)
28+
self._unique_id = hash(sycl_queue)
3129
except Exception:
3230
self._unique_id = self.rand_digit_str(16)
3331
super(DpctlSyclQueue, self).__init__(name="DpctlSyclQueue")
@@ -38,8 +36,14 @@ def rand_digit_str(self, n):
3836
)
3937

4038
@property
41-
def sycl_queue(self):
42-
return self._sycl_queue
39+
def sycl_device(self):
40+
"""Returns the SYCL oneAPI extension filter string associated with the
41+
queue.
42+
43+
Returns:
44+
str: A SYCL oneAPI extension filter string
45+
"""
46+
return self._device
4347

4448
@property
4549
def key(self):
@@ -69,11 +73,8 @@ def unbox_sycl_queue(typ, obj, c):
6973
qptr = qstruct._getpointer()
7074
ptr = c.builder.bitcast(qptr, c.pyapi.voidptr)
7175

72-
if c.context.enable_nrt:
73-
dpexrtCtx = dpexrt.DpexRTContext(c.context)
74-
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
75-
else:
76-
raise UnreachableError
76+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
77+
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
7778
is_error = cgutils.is_not_null(c.builder, errcode)
7879

7980
# Handle error

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from numba.core.types.npytypes import Array
1313
from numba.np.numpy_support import from_dtype
1414

15+
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
1516
from numba_dpex.utils import address_space
1617

1718

@@ -31,22 +32,28 @@ def __init__(
3132
aligned=True,
3233
addrspace=address_space.GLOBAL,
3334
):
34-
if queue and not isinstance(queue, types.misc.Omitted) and device:
35+
if (
36+
queue is not None
37+
and not (
38+
isinstance(queue, types.misc.Omitted)
39+
or isinstance(queue, types.misc.NoneType)
40+
)
41+
and device is not None
42+
):
3543
raise TypeError(
3644
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
37-
"`device` and `sycl_queue` are exclusive keywords, i.e. use one or other."
45+
"`device` and `sycl_queue` are exclusive keywords, "
46+
"i.e. use one or other."
3847
)
3948

40-
self.usm_type = usm_type
41-
self.addrspace = addrspace
42-
43-
if queue and not isinstance(queue, types.misc.Omitted):
44-
if not isinstance(queue, dpctl.SyclQueue):
49+
if queue is not None and not (
50+
isinstance(queue, types.misc.Omitted)
51+
or isinstance(queue, types.misc.NoneType)
52+
):
53+
if not isinstance(queue, DpctlSyclQueue):
4554
raise TypeError(
46-
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
47-
"The queue keyword arg should be a dpctl.SyclQueue object or None."
48-
"Found type(queue) ="
49-
+ str(type(queue) + " and queue =" + queue)
55+
"The queue keyword arg should be either DpctlSyclQueue or "
56+
"NoneType. Found type(queue) = " + str(type(queue))
5057
)
5158
self.queue = queue
5259
else:
@@ -55,24 +62,23 @@ def __init__(
5562
else:
5663
if not isinstance(device, str):
5764
raise TypeError(
58-
"numba_dpex.core.types.usm_ndarray_type.USMNdArray.__init__(): "
59-
"The device keyword arg should be a str object specifying "
60-
"a SYCL filter selector."
65+
"The device keyword arg should be a str object "
66+
"specifying a SYCL filter selector."
6167
)
6268
sycl_device = dpctl.SyclDevice(device)
6369

64-
self.queue = dpctl._sycl_queue_manager.get_device_cached_queue(
70+
sycl_queue = dpctl._sycl_queue_manager.get_device_cached_queue(
6571
sycl_device
6672
)
73+
self.queue = DpctlSyclQueue(sycl_queue=sycl_queue)
6774

68-
self.device = self.queue.sycl_device.filter_string
75+
self.device = self.queue.sycl_device
76+
self.usm_type = usm_type
77+
self.addrspace = addrspace
6978

7079
if not dtype:
7180
dummy_tensor = dpctl.tensor.empty(
72-
1,
73-
order=layout,
74-
usm_type=usm_type,
75-
sycl_queue=self.queue,
81+
1, order=layout, usm_type=usm_type, device=self.device
7682
)
7783
# convert dpnp type to numba/numpy type
7884
_dtype = dummy_tensor.dtype

numba_dpex/core/typing/typeof.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ def _typeof_helper(val, array_class_type):
4242
"The usm_type for the usm_ndarray could not be inferred"
4343
)
4444

45-
assert val.sycl_queue is not None
45+
if not val.sycl_queue:
46+
raise AssertionError
47+
48+
ty_queue = DpctlSyclQueue(sycl_queue=val.sycl_queue)
4649

4750
return array_class_type(
4851
dtype=dtype,
4952
ndim=val.ndim,
5053
layout=layout,
5154
readonly=readonly,
5255
usm_type=usm_type,
53-
queue=val.sycl_queue,
56+
queue=ty_queue,
5457
addrspace=address_space.GLOBAL,
5558
)
5659

0 commit comments

Comments
 (0)