Skip to content

Commit 8c3a7d5

Browse files
committed
Add support for kernel specific array views
1 parent 753ba3a commit 8c3a7d5

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

numba_dpex/core/kernel_interface/arrayobj.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313

1414
from numba.core import cgutils, types
15+
from numba.np import arrayobj as np_arrayobj
1516

1617
from numba_dpex.core import types as dpex_types
1718

@@ -79,3 +80,58 @@ def populate_array(array, data, shape, strides, itemsize):
7980
setattr(array, k, v)
8081

8182
return array
83+
84+
85+
def make_view(context, builder, aryty, ary, return_type, data, shapes, strides):
86+
"""
87+
Build a view over the given array with the given parameters.
88+
89+
This is analog of numpy.np.arrayobj.make_view without parent and
90+
meminfo fields, because they don't make sense on device. This function
91+
intended to be used only in kernel targets.
92+
"""
93+
retary = np_arrayobj.make_array(return_type)(context, builder)
94+
populate_array(
95+
retary, data=data, shape=shapes, strides=strides, itemsize=ary.itemsize
96+
)
97+
return retary
98+
99+
100+
def _getitem_array_generic(
101+
context, builder, return_type, aryty, ary, index_types, indices
102+
):
103+
"""
104+
Return the result of indexing *ary* with the given *indices*,
105+
returning either a scalar or a view.
106+
107+
This is analog of numpy.np.arrayobj._getitem_array_generic without parent
108+
and meminfo fields, because they don't make sense on device. This function
109+
intended to be used only in kernel targets.
110+
"""
111+
dataptr, view_shapes, view_strides = np_arrayobj.basic_indexing(
112+
context,
113+
builder,
114+
aryty,
115+
ary,
116+
index_types,
117+
indices,
118+
boundscheck=context.enable_boundscheck,
119+
)
120+
121+
if isinstance(return_type, types.Buffer):
122+
# Build array view
123+
retary = make_view(
124+
context,
125+
builder,
126+
aryty,
127+
ary,
128+
return_type,
129+
dataptr,
130+
view_shapes,
131+
view_strides,
132+
)
133+
return retary._getvalue()
134+
else:
135+
# Load scalar from 0-d result
136+
assert not view_shapes
137+
return np_arrayobj.load_item(context, builder, aryty, dataptr)

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@
66

77
import dpnp
88
from numba import errors, types
9-
from numba.core.imputils import lower_builtin
9+
from numba.core.imputils import impl_ret_borrowed, lower_builtin
1010
from numba.core.types import scalars
1111
from numba.core.types.containers import UniTuple
1212
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
1313
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
1414
from numba.extending import overload, overload_attribute
15-
from numba.np.arrayobj import getitem_arraynd_intp as np_getitem_arraynd_intp
15+
from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic
16+
from numba.np.arrayobj import make_array
1617
from numba.np.numpy_support import is_nonelike
1718

19+
from numba_dpex.core.kernel_interface.arrayobj import (
20+
_getitem_array_generic as kernel_getitem_array_generic,
21+
)
22+
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
1823
from numba_dpex.core.types import DpnpNdArray
1924

2025
from ._intrinsic import (
@@ -1077,9 +1082,24 @@ def getitem_arraynd_intp(context, builder, sig, args):
10771082
that when returning a view of a dpnp.ndarray the sycl::queue pointer
10781083
member in the LLVM IR struct gets properly updated.
10791084
"""
1080-
ret = np_getitem_arraynd_intp(context, builder, sig, args)
1085+
getitem_call_in_kernel = isinstance(context, DpexKernelTargetContext)
1086+
_getitem_array_generic = np_getitem_array_generic
1087+
1088+
if getitem_call_in_kernel:
1089+
_getitem_array_generic = kernel_getitem_array_generic
1090+
1091+
aryty, idxty = sig.args
1092+
ary, idx = args
1093+
1094+
assert aryty.ndim >= 1
1095+
ary = make_array(aryty)(context, builder, ary)
1096+
1097+
res = _getitem_array_generic(
1098+
context, builder, sig.return_type, aryty, ary, (idxty,), (idx,)
1099+
)
1100+
ret = impl_ret_borrowed(context, builder, sig.return_type, res)
10811101

1082-
if isinstance(sig.return_type, DpnpNdArray):
1102+
if isinstance(sig.return_type, DpnpNdArray) and not getitem_call_in_kernel:
10831103
array_val = args[0]
10841104
array_ty = sig.args[0]
10851105
sycl_queue_attr_pos = context.data_model_manager.lookup(

0 commit comments

Comments
 (0)