Skip to content

Commit ab22e21

Browse files
committed
Overload dpnp array's sycl_queue attribute
1 parent cbbe5e0 commit ab22e21

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
from dpctl import get_device_cached_queue
88
from llvmlite import ir as llvmir
9-
from llvmlite.ir import Constant
9+
from llvmlite.ir import Constant, IRBuilder
1010
from llvmlite.ir.types import DoubleType, FloatType
1111
from numba import types
1212
from numba.core import cgutils
1313
from numba.core import config as numba_config
14+
from numba.core import errors, imputils
1415
from numba.core.typing import signature
1516
from numba.extending import intrinsic, overload_classmethod
1617
from numba.np.arrayobj import (
@@ -1076,3 +1077,61 @@ def codegen(context, builder, sig, args):
10761077
return ary._getvalue()
10771078

10781079
return signature, codegen
1080+
1081+
1082+
@intrinsic
1083+
def ol_dpnp_nd_array_sycl_queue(
1084+
ty_context,
1085+
ty_dpnp_nd_array: DpnpNdArray,
1086+
):
1087+
if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
1088+
raise errors.TypingError("Argument must be DpnpNdArray")
1089+
1090+
ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue
1091+
1092+
sig = ty_queue(ty_dpnp_nd_array)
1093+
1094+
def codegen(context, builder: IRBuilder, sig, args: list):
1095+
array_proxy = cgutils.create_struct_proxy(ty_dpnp_nd_array)(
1096+
context,
1097+
builder,
1098+
value=args[0],
1099+
)
1100+
1101+
queue_ref = array_proxy.sycl_queue
1102+
1103+
queue_struct_proxy = cgutils.create_struct_proxy(ty_queue)(
1104+
context, builder
1105+
)
1106+
1107+
queue_struct_proxy.queue_ref = queue_ref
1108+
queue_struct_proxy.meminfo = array_proxy.meminfo
1109+
1110+
# Warning: current implementation prevents whole object from being
1111+
# destroyed as long as sycl_queue attribute is being used. It should be
1112+
# okay since anywere we use it as an argument callee creates a copy
1113+
# so it does not steel reference.
1114+
#
1115+
# We can avoid it by:
1116+
# queue_ref_copy = sycl.dpctl_queue_copy(builder, queue_ref) #noqa E800
1117+
# queue_struct_proxy.queue_ref = queue_ref_copy #noqa E800
1118+
# queue_struct->meminfo =
1119+
# nrt->manage_memory(queue_ref_copy, DPCTLEvent_Delete);
1120+
# but it will allocate new meminfo object which can negatively affect
1121+
# performance.
1122+
# Speaking philosophically attribute is a part of the object and as long
1123+
# as nobody can still the reference it is a part of the owner object
1124+
# and lifetime is tied to it.
1125+
# TODO: we want to have queue: queuestruct_t instead of
1126+
# queue_ref: QueueRef as an attribute for DPNPNdArray.
1127+
1128+
queue_value = queue_struct_proxy._getvalue()
1129+
1130+
# We need to incref meminfo so that queue model is preventing parent
1131+
# ndarray from being destroyed, that can destroy queue that we are
1132+
# using.
1133+
return imputils.impl_ret_borrowed(
1134+
context, builder, ty_queue, queue_value
1135+
)
1136+
1137+
return sig, codegen

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from numba.core.types.containers import UniTuple
1212
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
1313
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
14-
from numba.extending import overload
14+
from numba.extending import overload, overload_attribute
1515
from numba.np.arrayobj import getitem_arraynd_intp as np_getitem_arraynd_intp
1616
from numba.np.numpy_support import is_nonelike
1717

@@ -27,6 +27,7 @@
2727
impl_dpnp_ones_like,
2828
impl_dpnp_zeros,
2929
impl_dpnp_zeros_like,
30+
ol_dpnp_nd_array_sycl_queue,
3031
)
3132

3233
# =========================================================================
@@ -1085,3 +1086,23 @@ def getitem_arraynd_intp(context, builder, sig, args):
10851086
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)
10861087

10871088
return ret
1089+
1090+
1091+
@overload_attribute(DpnpNdArray, "sycl_queue")
1092+
def dpnp_nd_array_sycl_queue(arr):
1093+
"""Returns :class:`dpctl.SyclQueue` object associated with USM data.
1094+
1095+
This is an overloaded attribute implementation for dpnp.sycl_queue.
1096+
1097+
Args:
1098+
arr (numba_dpex.core.types.DpnpNdArray): Input array from which to
1099+
take sycl_queue.
1100+
1101+
Returns:
1102+
function: Local function `ol_dpnp_nd_array_sycl_queue()`.
1103+
"""
1104+
1105+
def get(arr):
1106+
return ol_dpnp_nd_array_sycl_queue(arr)
1107+
1108+
return get

0 commit comments

Comments
 (0)