Skip to content

Commit 725070e

Browse files
Diptorup DebZzEeKkAa
andcommitted
Remove kernel_interface.arrayobj
- Merges core.kernel_interface.arrayobj into kernel_api_impl.spirv.arrayobj - Updated populate_array to use USMNdArray type. - Remove core.kernel_api Co-authored-by: Yevhenii Havrylko <[email protected]>
1 parent abb36e1 commit 725070e

File tree

5 files changed

+139
-153
lines changed

5 files changed

+139
-153
lines changed

numba_dpex/core/kernel_interface/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

numba_dpex/core/kernel_interface/arrayobj.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from numba.np.arrayobj import make_array
1717
from numba.np.numpy_support import is_nonelike
1818

19-
from numba_dpex.core.kernel_interface.arrayobj import (
19+
from numba_dpex.core.types import DpnpNdArray
20+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
2021
_getitem_array_generic as kernel_getitem_array_generic,
2122
)
22-
from numba_dpex.core.types import DpnpNdArray
2323
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2424

2525
from ._intrinsic import (

numba_dpex/kernel_api_impl/spirv/arrayobj.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
# SPDX-FileCopyrightText: 2012 - 2024 Anaconda Inc.
12
# SPDX-FileCopyrightText: 2024 Intel Corporation
23
#
34
# SPDX-License-Identifier: Apache-2.0
5+
# SPDX-License-Identifier: BSD-2-Clause
46

57
"""Contains SPIR-V specific array functions."""
68

@@ -10,7 +12,138 @@
1012
from llvmlite.ir.builder import IRBuilder
1113
from numba.core import cgutils, errors, types
1214
from numba.core.base import BaseContext
13-
from numba.np.arrayobj import get_itemsize
15+
from numba.np.arrayobj import (
16+
basic_indexing,
17+
get_itemsize,
18+
load_item,
19+
make_array,
20+
)
21+
22+
from numba_dpex.core.types import USMNdArray
23+
24+
25+
def populate_array(
26+
arraystruct, data, shape, strides, itemsize
27+
): # pylint: disable=too-many-arguments,too-many-locals
28+
"""
29+
Helper function for populating array structures.
30+
31+
The function is copied from upstream Numba and modified to support the
32+
USMNdArray data type that uses a different data model on SYCL devices
33+
than the upstream types.Array data type. USMNdArray data model does not
34+
have the ``parent`` and ``meminfo`` fields. This function intended to be
35+
used only in the SPIRVKernelTarget.
36+
37+
*shape* and *strides* can be Python tuples or LLVM arrays.
38+
"""
39+
context = arraystruct._context # pylint: disable=protected-access
40+
builder = arraystruct._builder # pylint: disable=protected-access
41+
datamodel = arraystruct._datamodel # pylint: disable=protected-access
42+
# doesn't matter what this array type instance is, it's just to get the
43+
# fields for the data model of the standard array type in this context
44+
standard_array = USMNdArray(ndim=1, layout="C", dtype=types.float64)
45+
standard_array_type_datamodel = context.data_model_manager[standard_array]
46+
required_fields = set(standard_array_type_datamodel._fields)
47+
datamodel_fields = set(datamodel._fields)
48+
# Make sure that the presented array object has a data model that is
49+
# close enough to an array for this function to proceed.
50+
if (required_fields & datamodel_fields) != required_fields:
51+
missing = required_fields - datamodel_fields
52+
msg = (
53+
f"The datamodel for type {arraystruct} is missing "
54+
f"field{'s' if len(missing) > 1 else ''} {missing}."
55+
)
56+
raise ValueError(msg)
57+
58+
intp_t = context.get_value_type(types.intp)
59+
if isinstance(shape, (tuple, list)):
60+
shape = cgutils.pack_array(builder, shape, intp_t)
61+
if isinstance(strides, (tuple, list)):
62+
strides = cgutils.pack_array(builder, strides, intp_t)
63+
if isinstance(itemsize, int):
64+
itemsize = intp_t(itemsize)
65+
66+
attrs = {
67+
"shape": shape,
68+
"strides": strides,
69+
"data": data,
70+
"itemsize": itemsize,
71+
}
72+
73+
# Calc num of items from shape
74+
nitems = context.get_constant(types.intp, 1)
75+
unpacked_shape = cgutils.unpack_tuple(builder, shape, shape.type.count)
76+
# (note empty shape => 0d array therefore nitems = 1)
77+
for axlen in unpacked_shape:
78+
nitems = builder.mul(nitems, axlen, flags=["nsw"])
79+
attrs["nitems"] = nitems
80+
81+
# Make sure that we have all the fields
82+
got_fields = set(attrs.keys())
83+
if got_fields != required_fields:
84+
raise ValueError(f"missing {required_fields - got_fields}")
85+
86+
# Set field value
87+
for k, v in attrs.items():
88+
setattr(arraystruct, k, v)
89+
90+
return arraystruct
91+
92+
93+
def make_view(
94+
context, builder, ary, return_type, data, shapes, strides
95+
): # pylint: disable=too-many-arguments
96+
"""
97+
Build a view over the given array with the given parameters.
98+
99+
This is analog of numpy.np.arrayobj.make_view without parent and
100+
meminfo fields, because they don't make sense on device. This function
101+
intended to be used only in kernel targets.
102+
"""
103+
retary = make_array(return_type)(context, builder)
104+
context.populate_array(
105+
retary, data=data, shape=shapes, strides=strides, itemsize=ary.itemsize
106+
)
107+
return retary
108+
109+
110+
def _getitem_array_generic(
111+
context, builder, return_type, aryty, ary, index_types, indices
112+
): # pylint: disable=too-many-arguments
113+
"""
114+
Return the result of indexing *ary* with the given *indices*,
115+
returning either a scalar or a view.
116+
117+
This is analog of numpy.np.arrayobj._getitem_array_generic without parent
118+
and meminfo fields, because they don't make sense on device. This function
119+
intended to be used only in kernel targets.
120+
"""
121+
dataptr, view_shapes, view_strides = basic_indexing(
122+
context,
123+
builder,
124+
aryty,
125+
ary,
126+
index_types,
127+
indices,
128+
boundscheck=context.enable_boundscheck,
129+
)
130+
131+
if isinstance(return_type, types.Buffer):
132+
# Build array view
133+
retary = make_view(
134+
context,
135+
builder,
136+
ary,
137+
return_type,
138+
dataptr,
139+
view_shapes,
140+
view_strides,
141+
)
142+
return retary._getvalue() # pylint: disable=protected-access
143+
144+
# Load scalar from 0-d result
145+
assert not view_shapes
146+
return load_item(context, builder, aryty, dataptr)
14147

15148

16149
def require_literal(literal_type: types.Type):

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from numba_dpex.core.types import IntEnumLiteral
2525
from numba_dpex.core.typing import dpnpdecl
2626
from numba_dpex.kernel_api.flag_enum import FlagEnum
27+
from numba_dpex.kernel_api_impl.spirv.arrayobj import populate_array
2728
from numba_dpex.ocl.mathimpl import lower_ocl_impl, sig_mapper
2829
from numba_dpex.utils import address_space, calling_conv
2930

@@ -37,9 +38,7 @@
3738

3839
class CompilationMode(IntEnum):
3940
"""Flags used to determine how a function should be compiled by the
40-
numba_dpex.experimental.dispatcher.KernelDispatcher. Note the functionality
41-
will be merged into numba_dpex.core.kernel_interface.dispatcher in the
42-
future.
41+
numba_dpex.kernel_api_impl_spirv.dispatcher.KernelDispatcher.
4342
4443
KERNEL : Indicates that the function will be compiled into an
4544
LLVM function that has ``spir_kernel`` calling
@@ -419,10 +418,7 @@ def populate_array(self, arr, **kwargs):
419418
"""
420419
Populate array structure.
421420
"""
422-
# pylint: disable=import-outside-toplevel
423-
from numba_dpex.core.kernel_interface import arrayobj
424-
425-
return arrayobj.populate_array(arr, **kwargs)
421+
return populate_array(arr, **kwargs)
426422

427423
def get_executable(self, func, fndesc, env):
428424
"""Not implemented for SPIRVTargetContext"""

0 commit comments

Comments
 (0)