|
6 | 6 |
|
7 | 7 | import dpnp
|
8 | 8 | from numba import errors, types
|
9 |
| -from numba.core.imputils import lower_builtin |
| 9 | +from numba.core.imputils import impl_ret_borrowed, lower_builtin |
10 | 10 | from numba.core.types import scalars
|
11 | 11 | from numba.core.types.containers import UniTuple
|
12 | 12 | from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
|
13 | 13 | from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
|
14 | 14 | from numba.extending import overload, overload_attribute
|
15 |
| -from numba.np.arrayobj import getitem_arraynd_intp as np_getitem_arraynd_intp |
| 15 | +from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic |
| 16 | +from numba.np.arrayobj import make_array |
16 | 17 | from numba.np.numpy_support import is_nonelike
|
17 | 18 |
|
| 19 | +from numba_dpex.core.kernel_interface.arrayobj import ( |
| 20 | + _getitem_array_generic as kernel_getitem_array_generic, |
| 21 | +) |
| 22 | +from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext |
18 | 23 | from numba_dpex.core.types import DpnpNdArray
|
19 | 24 |
|
20 | 25 | from ._intrinsic import (
|
@@ -1077,9 +1082,24 @@ def getitem_arraynd_intp(context, builder, sig, args):
|
1077 | 1082 | that when returning a view of a dpnp.ndarray the sycl::queue pointer
|
1078 | 1083 | member in the LLVM IR struct gets properly updated.
|
1079 | 1084 | """
|
1080 |
| - ret = np_getitem_arraynd_intp(context, builder, sig, args) |
| 1085 | + getitem_call_in_kernel = isinstance(context, DpexKernelTargetContext) |
| 1086 | + _getitem_array_generic = np_getitem_array_generic |
| 1087 | + |
| 1088 | + if getitem_call_in_kernel: |
| 1089 | + _getitem_array_generic = kernel_getitem_array_generic |
| 1090 | + |
| 1091 | + aryty, idxty = sig.args |
| 1092 | + ary, idx = args |
| 1093 | + |
| 1094 | + assert aryty.ndim >= 1 |
| 1095 | + ary = make_array(aryty)(context, builder, ary) |
| 1096 | + |
| 1097 | + res = _getitem_array_generic( |
| 1098 | + context, builder, sig.return_type, aryty, ary, (idxty,), (idx,) |
| 1099 | + ) |
| 1100 | + ret = impl_ret_borrowed(context, builder, sig.return_type, res) |
1081 | 1101 |
|
1082 |
| - if isinstance(sig.return_type, DpnpNdArray): |
| 1102 | + if isinstance(sig.return_type, DpnpNdArray) and not getitem_call_in_kernel: |
1083 | 1103 | array_val = args[0]
|
1084 | 1104 | array_ty = sig.args[0]
|
1085 | 1105 | sycl_queue_attr_pos = context.data_model_manager.lookup(
|
|
0 commit comments