Skip to content

Commit ebc6c32

Browse files
author
Diptorup Deb
authored
Merge pull request #1419 from IntelPython/remove/core.kernel_interface
Refactors and removes core.kernel interface
2 parents 0a21ec2 + 725070e commit ebc6c32

File tree

9 files changed

+164
-204
lines changed

9 files changed

+164
-204
lines changed

numba_dpex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
8484
import numba_dpex.core.types as types # noqa E402
8585
from numba_dpex.core import boxing # noqa E402
8686
from numba_dpex.core import config # noqa E402
87-
from numba_dpex.core.kernel_interface import ranges_overloads # noqa E402
87+
from numba_dpex.core.overloads import ranges_overloads # noqa E402
8888

8989
# Re-export all type names
9090
from numba_dpex.core.types import * # noqa E402

numba_dpex/core/kernel_interface/arrayobj.py

Lines changed: 0 additions & 137 deletions
This file was deleted.
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
11
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
5-
"""Defines the interface for kernel compilation using numba-dpex.
6-
"""

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from numba.np.arrayobj import make_array
1717
from numba.np.numpy_support import is_nonelike
1818

19-
from numba_dpex.core.kernel_interface.arrayobj import (
19+
from numba_dpex.core.types import DpnpNdArray
20+
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
2021
_getitem_array_generic as kernel_getitem_array_generic,
2122
)
22-
from numba_dpex.core.types import DpnpNdArray
2323
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2424

2525
from ._intrinsic import (

numba_dpex/kernel_api_impl/spirv/arrayobj.py

Lines changed: 139 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,149 @@
1+
# SPDX-FileCopyrightText: 2012 - 2024 Anaconda Inc.
12
# SPDX-FileCopyrightText: 2024 Intel Corporation
23
#
34
# SPDX-License-Identifier: Apache-2.0
5+
# SPDX-License-Identifier: BSD-2-Clause
46

57
"""Contains SPIR-V specific array functions."""
68

7-
import operator
8-
from functools import reduce
99
from typing import Union
1010

1111
import llvmlite.ir as llvmir
1212
from llvmlite.ir.builder import IRBuilder
1313
from numba.core import cgutils, errors, types
1414
from numba.core.base import BaseContext
15+
from numba.np.arrayobj import (
16+
basic_indexing,
17+
get_itemsize,
18+
load_item,
19+
make_array,
20+
)
1521

16-
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
17-
from numba_dpex.ocl.oclimpl import _get_target_data
22+
from numba_dpex.core.types import USMNdArray
1823

1924

20-
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
25+
def populate_array(
26+
arraystruct, data, shape, strides, itemsize
27+
): # pylint: disable=too-many-arguments,too-many-locals
2128
"""
22-
Return the item size for the given array or buffer type.
23-
Same as numba.np.arrayobj.get_itemsize, but using spirv data.
29+
Helper function for populating array structures.
30+
31+
The function is copied from upstream Numba and modified to support the
32+
USMNdArray data type that uses a different data model on SYCL devices
33+
than the upstream types.Array data type. USMNdArray data model does not
34+
have the ``parent`` and ``meminfo`` fields. This function intended to be
35+
used only in the SPIRVKernelTarget.
36+
37+
*shape* and *strides* can be Python tuples or LLVM arrays.
38+
"""
39+
context = arraystruct._context # pylint: disable=protected-access
40+
builder = arraystruct._builder # pylint: disable=protected-access
41+
datamodel = arraystruct._datamodel # pylint: disable=protected-access
42+
# doesn't matter what this array type instance is, it's just to get the
43+
# fields for the data model of the standard array type in this context
44+
standard_array = USMNdArray(ndim=1, layout="C", dtype=types.float64)
45+
standard_array_type_datamodel = context.data_model_manager[standard_array]
46+
required_fields = set(standard_array_type_datamodel._fields)
47+
datamodel_fields = set(datamodel._fields)
48+
# Make sure that the presented array object has a data model that is
49+
# close enough to an array for this function to proceed.
50+
if (required_fields & datamodel_fields) != required_fields:
51+
missing = required_fields - datamodel_fields
52+
msg = (
53+
f"The datamodel for type {arraystruct} is missing "
54+
f"field{'s' if len(missing) > 1 else ''} {missing}."
55+
)
56+
raise ValueError(msg)
57+
58+
intp_t = context.get_value_type(types.intp)
59+
if isinstance(shape, (tuple, list)):
60+
shape = cgutils.pack_array(builder, shape, intp_t)
61+
if isinstance(strides, (tuple, list)):
62+
strides = cgutils.pack_array(builder, strides, intp_t)
63+
if isinstance(itemsize, int):
64+
itemsize = intp_t(itemsize)
65+
66+
attrs = {
67+
"shape": shape,
68+
"strides": strides,
69+
"data": data,
70+
"itemsize": itemsize,
71+
}
72+
73+
# Calc num of items from shape
74+
nitems = context.get_constant(types.intp, 1)
75+
unpacked_shape = cgutils.unpack_tuple(builder, shape, shape.type.count)
76+
# (note empty shape => 0d array therefore nitems = 1)
77+
for axlen in unpacked_shape:
78+
nitems = builder.mul(nitems, axlen, flags=["nsw"])
79+
attrs["nitems"] = nitems
80+
81+
# Make sure that we have all the fields
82+
got_fields = set(attrs.keys())
83+
if got_fields != required_fields:
84+
raise ValueError(f"missing {required_fields - got_fields}")
85+
86+
# Set field value
87+
for k, v in attrs.items():
88+
setattr(arraystruct, k, v)
89+
90+
return arraystruct
91+
92+
93+
def make_view(
94+
context, builder, ary, return_type, data, shapes, strides
95+
): # pylint: disable=too-many-arguments
96+
"""
97+
Build a view over the given array with the given parameters.
98+
99+
This is analog of numpy.np.arrayobj.make_view without parent and
100+
meminfo fields, because they don't make sense on device. This function
101+
intended to be used only in kernel targets.
102+
"""
103+
retary = make_array(return_type)(context, builder)
104+
context.populate_array(
105+
retary, data=data, shape=shapes, strides=strides, itemsize=ary.itemsize
106+
)
107+
return retary
108+
109+
110+
def _getitem_array_generic(
111+
context, builder, return_type, aryty, ary, index_types, indices
112+
): # pylint: disable=too-many-arguments
24113
"""
25-
targetdata = _get_target_data(context)
26-
lldtype = context.get_data_type(array_type.dtype)
27-
return lldtype.get_abi_size(targetdata)
114+
Return the result of indexing *ary* with the given *indices*,
115+
returning either a scalar or a view.
116+
117+
This is analog of numpy.np.arrayobj._getitem_array_generic without parent
118+
and meminfo fields, because they don't make sense on device. This function
119+
intended to be used only in kernel targets.
120+
"""
121+
dataptr, view_shapes, view_strides = basic_indexing(
122+
context,
123+
builder,
124+
aryty,
125+
ary,
126+
index_types,
127+
indices,
128+
boundscheck=context.enable_boundscheck,
129+
)
130+
131+
if isinstance(return_type, types.Buffer):
132+
# Build array view
133+
retary = make_view(
134+
context,
135+
builder,
136+
ary,
137+
return_type,
138+
dataptr,
139+
view_shapes,
140+
view_strides,
141+
)
142+
return retary._getvalue() # pylint: disable=protected-access
143+
144+
# Load scalar from 0-d result
145+
assert not view_shapes
146+
return load_item(context, builder, aryty, dataptr)
28147

29148

30149
def require_literal(literal_type: types.Type):
@@ -46,15 +165,22 @@ def require_literal(literal_type: types.Type):
46165
)
47166

48167

49-
def make_spirv_array( # pylint: disable=too-many-arguments
50-
context: SPIRVTargetContext,
168+
def np_cfarray( # pylint: disable=too-many-arguments
169+
context: BaseContext,
51170
builder: IRBuilder,
52171
ty_array: types.Array,
53172
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
54173
shape: llvmir.Value,
55174
data: llvmir.Value,
56175
):
57-
"""Makes SPIR-V array and fills it data."""
176+
"""Makes numpy-like array and fills it's data depending on the context's
177+
datamodel.
178+
179+
Generic version of numba.np.arrayobj.np_cfarray so that it can be used
180+
not only as intrinsic, but inside instruction generation.
181+
182+
TODO: upstream changes to numba.
183+
"""
58184
# Create array object
59185
ary = context.make_array(ty_array)(context, builder)
60186

@@ -92,32 +218,3 @@ def make_spirv_array( # pylint: disable=too-many-arguments
92218
)
93219

94220
return ary
95-
96-
97-
def allocate_array_data_on_stack(
98-
context: BaseContext,
99-
builder: IRBuilder,
100-
ty_array: types.Array,
101-
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
102-
):
103-
"""Allocates flat array of given shape on the stack."""
104-
if not isinstance(ty_shape, types.BaseTuple):
105-
ty_shape = (ty_shape,)
106-
107-
return cgutils.alloca_once(
108-
builder,
109-
context.get_data_type(ty_array.dtype),
110-
size=reduce(operator.mul, [s.literal_value for s in ty_shape]),
111-
)
112-
113-
114-
def make_spirv_generic_array_on_stack(
115-
context: SPIRVTargetContext,
116-
builder: IRBuilder,
117-
ty_array: types.Array,
118-
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
119-
shape: llvmir.Value,
120-
):
121-
"""Makes SPIR-V array of given shape with empty data."""
122-
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape)
123-
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data)

0 commit comments

Comments
 (0)