Skip to content

Commit 3956fe7

Browse files
committed
Move getitem implementation to usm from dpnp
1 parent 8f04945 commit 3956fe7

File tree

5 files changed

+215
-173
lines changed

5 files changed

+215
-173
lines changed

numba_dpex/dpctl_iface/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@
66
The ``dpctl_iface`` module implements Numba's interface to the libsyclinterface
77
library that provides C bindings to DPC++'s SYCL runtime API.
88
"""
9+
10+
from . import arrayobj

numba_dpex/dpctl_iface/arrayobj.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import operator
6+
7+
from llvmlite.ir import IRBuilder
8+
from numba.core import cgutils, errors, imputils, types
9+
from numba.core.imputils import impl_ret_borrowed
10+
from numba.extending import intrinsic, overload_attribute
11+
from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic
12+
from numba.np.arrayobj import make_array
13+
14+
from numba_dpex.core.types import DpnpNdArray, USMNdArray
15+
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
16+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
17+
_getitem_array_generic as kernel_getitem_array_generic,
18+
)
19+
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
20+
21+
from .dpctlimpl import lower_builtin
22+
23+
# can't import name because of the circular import
24+
DPEX_TARGET_NAME = "dpex"
25+
26+
# =========================================================================
27+
# Helps to parse dpnp constructor arguments
28+
# =========================================================================
29+
30+
31+
# TODO: target specific
32+
@lower_builtin(operator.getitem, USMNdArray, types.Integer)
33+
@lower_builtin(operator.getitem, USMNdArray, types.SliceType)
34+
def getitem_arraynd_intp(context, builder, sig, args):
35+
"""
36+
Overrding the numba.np.arrayobj.getitem_arraynd_intp to support dpnp.ndarray
37+
38+
The data model for numba.types.Array and numba_dpex.types.DpnpNdArray
39+
are different. DpnpNdArray has an extra attribute to store a sycl::queue
40+
pointer. For that reason, np_getitem_arraynd_intp needs to be overriden so
41+
that when returning a view of a dpnp.ndarray the sycl::queue pointer
42+
member in the LLVM IR struct gets properly updated.
43+
"""
44+
getitem_call_in_kernel = isinstance(context, SPIRVTargetContext)
45+
_getitem_array_generic = np_getitem_array_generic
46+
47+
if getitem_call_in_kernel:
48+
_getitem_array_generic = kernel_getitem_array_generic
49+
50+
aryty, idxty = sig.args
51+
ary, idx = args
52+
53+
assert aryty.ndim >= 1
54+
ary = make_array(aryty)(context, builder, ary)
55+
56+
res = _getitem_array_generic(
57+
context, builder, sig.return_type, aryty, ary, (idxty,), (idx,)
58+
)
59+
ret = impl_ret_borrowed(context, builder, sig.return_type, res)
60+
61+
if isinstance(sig.return_type, USMNdArray) and not getitem_call_in_kernel:
62+
array_val = args[0]
63+
array_ty = sig.args[0]
64+
sycl_queue_attr_pos = context.data_model_manager.lookup(
65+
array_ty
66+
).get_field_position("sycl_queue")
67+
sycl_queue_attr = builder.extract_value(array_val, sycl_queue_attr_pos)
68+
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)
69+
70+
return ret
71+
72+
73+
@intrinsic(target=DPEX_TARGET_NAME)
74+
def ol_usm_nd_array_sycl_queue(
75+
ty_context,
76+
ty_dpnp_nd_array: DpnpNdArray,
77+
):
78+
if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
79+
raise errors.TypingError("Argument must be DpnpNdArray")
80+
81+
ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue
82+
83+
sig = ty_queue(ty_dpnp_nd_array)
84+
85+
def codegen(context, builder: IRBuilder, sig, args: list):
86+
array_proxy = cgutils.create_struct_proxy(ty_dpnp_nd_array)(
87+
context,
88+
builder,
89+
value=args[0],
90+
)
91+
92+
queue_ref = array_proxy.sycl_queue
93+
94+
queue_struct_proxy = cgutils.create_struct_proxy(ty_queue)(
95+
context, builder
96+
)
97+
98+
queue_struct_proxy.queue_ref = queue_ref
99+
queue_struct_proxy.meminfo = array_proxy.meminfo
100+
101+
# Warning: current implementation prevents whole object from being
102+
# destroyed as long as sycl_queue attribute is being used. It should be
103+
# okay since anywere we use it as an argument callee creates a copy
104+
# so it does not steel reference.
105+
#
106+
# We can avoid it by:
107+
# queue_ref_copy = sycl.dpctl_queue_copy(builder, queue_ref) #noqa E800
108+
# queue_struct_proxy.queue_ref = queue_ref_copy #noqa E800
109+
# queue_struct->meminfo =
110+
# nrt->manage_memory(queue_ref_copy, DPCTLEvent_Delete);
111+
# but it will allocate new meminfo object which can negatively affect
112+
# performance.
113+
# Speaking philosophically attribute is a part of the object and as long
114+
# as nobody can still the reference it is a part of the owner object
115+
# and lifetime is tied to it.
116+
# TODO: we want to have queue: queuestruct_t instead of
117+
# queue_ref: QueueRef as an attribute for DPNPNdArray.
118+
119+
queue_value = queue_struct_proxy._getvalue()
120+
121+
# We need to incref meminfo so that queue model is preventing parent
122+
# ndarray from being destroyed, that can destroy queue that we are
123+
# using.
124+
return imputils.impl_ret_borrowed(
125+
context, builder, ty_queue, queue_value
126+
)
127+
128+
return sig, codegen
129+
130+
131+
@overload_attribute(USMNdArray, "sycl_queue", target=DPEX_TARGET_NAME)
132+
def dpnp_nd_array_sycl_queue(arr):
133+
"""Returns :class:`dpctl.SyclQueue` object associated with USM data.
134+
135+
This is an overloaded attribute implementation for dpnp.sycl_queue.
136+
137+
Args:
138+
arr (numba_dpex.core.types.DpnpNdArray): Input array from which to
139+
take sycl_queue.
140+
141+
Returns:
142+
function: Local function `ol_dpnp_nd_array_sycl_queue()`.
143+
"""
144+
145+
def get(arr):
146+
return ol_usm_nd_array_sycl_queue(arr)
147+
148+
return get

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66

77
from dpctl import get_device_cached_queue
88
from llvmlite import ir as llvmir
9-
from llvmlite.ir import Constant, IRBuilder
9+
from llvmlite.ir import Constant
1010
from llvmlite.ir.types import DoubleType, FloatType
1111
from numba import types
1212
from numba.core import cgutils
1313
from numba.core import config as numba_config
14-
from numba.core import errors, imputils
1514
from numba.core.typing import signature
1615
from numba.extending import intrinsic, overload_classmethod
1716
from numba.np.arrayobj import (
@@ -1081,61 +1080,3 @@ def codegen(context, builder, sig, args):
10811080
return ary._getvalue()
10821081

10831082
return signature, codegen
1084-
1085-
1086-
@intrinsic(target=DPEX_TARGET_NAME)
1087-
def ol_dpnp_nd_array_sycl_queue(
1088-
ty_context,
1089-
ty_dpnp_nd_array: DpnpNdArray,
1090-
):
1091-
if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
1092-
raise errors.TypingError("Argument must be DpnpNdArray")
1093-
1094-
ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue
1095-
1096-
sig = ty_queue(ty_dpnp_nd_array)
1097-
1098-
def codegen(context, builder: IRBuilder, sig, args: list):
1099-
array_proxy = cgutils.create_struct_proxy(ty_dpnp_nd_array)(
1100-
context,
1101-
builder,
1102-
value=args[0],
1103-
)
1104-
1105-
queue_ref = array_proxy.sycl_queue
1106-
1107-
queue_struct_proxy = cgutils.create_struct_proxy(ty_queue)(
1108-
context, builder
1109-
)
1110-
1111-
queue_struct_proxy.queue_ref = queue_ref
1112-
queue_struct_proxy.meminfo = array_proxy.meminfo
1113-
1114-
# Warning: current implementation prevents whole object from being
1115-
# destroyed as long as sycl_queue attribute is being used. It should be
1116-
# okay since anywere we use it as an argument callee creates a copy
1117-
# so it does not steel reference.
1118-
#
1119-
# We can avoid it by:
1120-
# queue_ref_copy = sycl.dpctl_queue_copy(builder, queue_ref) #noqa E800
1121-
# queue_struct_proxy.queue_ref = queue_ref_copy #noqa E800
1122-
# queue_struct->meminfo =
1123-
# nrt->manage_memory(queue_ref_copy, DPCTLEvent_Delete);
1124-
# but it will allocate new meminfo object which can negatively affect
1125-
# performance.
1126-
# Speaking philosophically attribute is a part of the object and as long
1127-
# as nobody can still the reference it is a part of the owner object
1128-
# and lifetime is tied to it.
1129-
# TODO: we want to have queue: queuestruct_t instead of
1130-
# queue_ref: QueueRef as an attribute for DPNPNdArray.
1131-
1132-
queue_value = queue_struct_proxy._getvalue()
1133-
1134-
# We need to incref meminfo so that queue model is preventing parent
1135-
# ndarray from being destroyed, that can destroy queue that we are
1136-
# using.
1137-
return imputils.impl_ret_borrowed(
1138-
context, builder, ty_queue, queue_value
1139-
)
1140-
1141-
return sig, codegen

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import operator
6-
75
import dpnp
86
from numba import errors, types
9-
from numba.core.imputils import impl_ret_borrowed, lower_builtin
107
from numba.core.types import scalars
118
from numba.core.types.containers import UniTuple
129
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
1310
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
14-
from numba.extending import overload, overload_attribute
15-
from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic
16-
from numba.np.arrayobj import make_array
11+
from numba.extending import overload
1712
from numba.np.numpy_support import is_nonelike
1813

1914
from numba_dpex.core.types import DpnpNdArray
20-
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
21-
_getitem_array_generic as kernel_getitem_array_generic,
22-
)
23-
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2415

2516
from ._intrinsic import (
2617
impl_dpnp_empty,
@@ -31,7 +22,6 @@
3122
impl_dpnp_ones_like,
3223
impl_dpnp_zeros,
3324
impl_dpnp_zeros_like,
34-
ol_dpnp_nd_array_sycl_queue,
3525
)
3626

3727
# can't import name because of the circular import
@@ -1067,65 +1057,3 @@ def impl(
10671057
"Cannot parse input types to "
10681058
+ f"function dpnp.full_like({x1}, {fill_value}, {dtype}, ...)."
10691059
)
1070-
1071-
1072-
# TODO: target specific
1073-
@lower_builtin(operator.getitem, DpnpNdArray, types.Integer)
1074-
@lower_builtin(operator.getitem, DpnpNdArray, types.SliceType)
1075-
def getitem_arraynd_intp(context, builder, sig, args):
1076-
"""
1077-
Overrding the numba.np.arrayobj.getitem_arraynd_intp to support dpnp.ndarray
1078-
1079-
The data model for numba.types.Array and numba_dpex.types.DpnpNdArray
1080-
are different. DpnpNdArray has an extra attribute to store a sycl::queue
1081-
pointer. For that reason, np_getitem_arraynd_intp needs to be overriden so
1082-
that when returning a view of a dpnp.ndarray the sycl::queue pointer
1083-
member in the LLVM IR struct gets properly updated.
1084-
"""
1085-
getitem_call_in_kernel = isinstance(context, SPIRVTargetContext)
1086-
_getitem_array_generic = np_getitem_array_generic
1087-
1088-
if getitem_call_in_kernel:
1089-
_getitem_array_generic = kernel_getitem_array_generic
1090-
1091-
aryty, idxty = sig.args
1092-
ary, idx = args
1093-
1094-
assert aryty.ndim >= 1
1095-
ary = make_array(aryty)(context, builder, ary)
1096-
1097-
res = _getitem_array_generic(
1098-
context, builder, sig.return_type, aryty, ary, (idxty,), (idx,)
1099-
)
1100-
ret = impl_ret_borrowed(context, builder, sig.return_type, res)
1101-
1102-
if isinstance(sig.return_type, DpnpNdArray) and not getitem_call_in_kernel:
1103-
array_val = args[0]
1104-
array_ty = sig.args[0]
1105-
sycl_queue_attr_pos = context.data_model_manager.lookup(
1106-
array_ty
1107-
).get_field_position("sycl_queue")
1108-
sycl_queue_attr = builder.extract_value(array_val, sycl_queue_attr_pos)
1109-
ret = builder.insert_value(ret, sycl_queue_attr, sycl_queue_attr_pos)
1110-
1111-
return ret
1112-
1113-
1114-
@overload_attribute(DpnpNdArray, "sycl_queue", target=DPEX_TARGET_NAME)
1115-
def dpnp_nd_array_sycl_queue(arr):
1116-
"""Returns :class:`dpctl.SyclQueue` object associated with USM data.
1117-
1118-
This is an overloaded attribute implementation for dpnp.sycl_queue.
1119-
1120-
Args:
1121-
arr (numba_dpex.core.types.DpnpNdArray): Input array from which to
1122-
take sycl_queue.
1123-
1124-
Returns:
1125-
function: Local function `ol_dpnp_nd_array_sycl_queue()`.
1126-
"""
1127-
1128-
def get(arr):
1129-
return ol_dpnp_nd_array_sycl_queue(arr)
1130-
1131-
return get

0 commit comments

Comments
 (0)