Skip to content

Commit 1854eb6

Browse files
author
Diptorup Deb
authored
Merge pull request #1306 from IntelPython/refactor/ranges
Add range classes to kernel_api
2 parents aeb844a + 8ede9dc commit 1854eb6

File tree

21 files changed

+406
-362
lines changed

21 files changed

+406
-362
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ repos:
5252
- id: pylint
5353
name: pylint
5454
entry: pylint
55-
files: ^numba_dpex/experimental|^numba_dpex/core/utils/kernel_launcher.py
55+
files: ^numba_dpex/kernel_api|^numba_dpex/experimental|^numba_dpex/core/utils/kernel_launcher.py
5656
language: system
5757
types: [python]
5858
require_serial: true

numba_dpex/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,19 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
9898

9999
# Re-export types itself
100100
import numba_dpex.core.types as types # noqa E402
101+
from numba_dpex.core import boxing # noqa E402
101102
from numba_dpex.core import config # noqa E402
102-
from numba_dpex.core.kernel_interface.indexers import ( # noqa E402
103-
NdRange,
104-
Range,
105-
)
103+
from numba_dpex.core.kernel_interface import ranges_overloads # noqa E402
106104

107105
# Re-export all type names
108106
from numba_dpex.core.types import * # noqa E402
109107
from numba_dpex.dpctl_iface import _intrinsic # noqa E402
110108
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402
111109

110+
# Importing NdRange and Range into numba_dpex for
111+
# backward compatibility
112+
from numba_dpex.kernel_api import NdRange, Range # noqa E402
113+
112114
from .core.targets import dpjit_target, kernel_target # noqa E402
113115
from .decorators import dpjit, func, kernel # noqa E402
114116
from .ocl.stubs import ( # noqa E402

numba_dpex/core/boxing/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Contains the ``box`` and ``unbox`` functions for numba_dpex types that are
6+
passable as arguments to a kernel or dpjit decorated function.
7+
"""
8+
9+
from .ranges import *
10+
from .usm_ndarray import *

numba_dpex/core/types/range_types.py renamed to numba_dpex/core/boxing/ranges.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,55 +4,12 @@
44

55
from contextlib import ExitStack
66

7-
from numba.core import cgutils, errors, types
7+
from numba.core import cgutils, types
88
from numba.core.datamodel import default_manager
99
from numba.extending import NativeValue, box, unbox
1010

11-
from ..kernel_interface.indexers import NdRange, Range
12-
13-
14-
class RangeType(types.Type):
15-
"""Numba-dpex type corresponding to
16-
:class:`numba_dpex.core.kernel_interface.indexers.Range`
17-
"""
18-
19-
def __init__(self, ndim: int):
20-
self._ndim = ndim
21-
if ndim < 1 or ndim > 3:
22-
raise errors.TypingError(
23-
"RangeType can only have 1,2, or 3 dimensions"
24-
)
25-
super(RangeType, self).__init__(name="Range<" + str(ndim) + ">")
26-
27-
@property
28-
def ndim(self):
29-
return self._ndim
30-
31-
@property
32-
def key(self):
33-
return self._ndim
34-
35-
36-
class NdRangeType(types.Type):
37-
"""Numba-dpex type corresponding to
38-
:class:`numba_dpex.core.kernel_interface.indexers.NdRange`
39-
"""
40-
41-
def __init__(self, ndim: int):
42-
self._ndim = ndim
43-
if ndim < 1 or ndim > 3:
44-
raise errors.TypingError(
45-
"RangeType can only have 1,2, or 3 dimensions"
46-
)
47-
super(NdRangeType, self).__init__(name="NdRange<" + str(ndim) + ">")
48-
49-
@property
50-
def ndim(self):
51-
return self._ndim
52-
53-
@property
54-
def key(self):
55-
return self._ndim
11+
from numba_dpex.core.types import NdRangeType, RangeType
12+
from numba_dpex.kernel_api import NdRange, Range
5613

5714

5815
@unbox(RangeType)

numba_dpex/core/boxing/usm_ndarray.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from contextlib import ExitStack
6+
7+
from numba.core import cgutils, types
8+
from numba.core.datamodel import default_manager
9+
from numba.core.errors import NumbaNotImplementedError
10+
from numba.extending import NativeValue, box, unbox
11+
from numba.np import numpy_support
12+
13+
from numba_dpex.core.exceptions import UnreachableError
14+
from numba_dpex.core.runtime import context as dpexrt
15+
from numba_dpex.core.types import USMNdArray
16+
from numba_dpex.kernel_api import NdRange, Range
17+
18+
19+
@unbox(USMNdArray)
20+
def unbox_dpnp_nd_array(typ, obj, c):
21+
"""Converts a dpctl.tensor.usm_ndarray/dpnp.ndarray object to a Numba-dpex
22+
internal array structure.
23+
24+
Args:
25+
typ : The Numba type of the PyObject
26+
obj : The actual PyObject to be unboxed
27+
c : The unboxing context
28+
29+
Returns: A NativeValue object representing an unboxed
30+
dpctl.tensor.usm_ndarray/dpnp.ndarray
31+
"""
32+
# Reusing the numba.core.base.BaseContext's make_array function to get a
33+
# struct allocated. The same struct is used for numpy.ndarray
34+
# and dpnp.ndarray. It is possible to do so, as the extra information
35+
# specific to dpnp.ndarray such as sycl_queue is inferred statically and
36+
# stored as part of the DpnpNdArray type.
37+
38+
# --------------- Original Numba comment from @ubox(types.Array)
39+
#
40+
# This is necessary because unbox_buffer() does not work on some
41+
# dtypes, e.g. datetime64 and timedelta64.
42+
# TODO check matching dtype.
43+
# currently, mismatching dtype will still work and causes
44+
# potential memory corruption
45+
#
46+
# --------------- End of Numba comment from @ubox(types.Array)
47+
nativearycls = c.context.make_array(typ)
48+
nativeary = nativearycls(c.context, c.builder)
49+
aryptr = nativeary._getpointer()
50+
51+
ptr = c.builder.bitcast(aryptr, c.pyapi.voidptr)
52+
# FIXME : We need to check if Numba_RT as well as DPEX RT are enabled.
53+
if c.context.enable_nrt:
54+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
55+
errcode = dpexrtCtx.arraystruct_from_python(c.pyapi, obj, ptr)
56+
else:
57+
raise UnreachableError
58+
59+
# TODO: here we have minimal typechecking by the itemsize.
60+
# need to do better
61+
try:
62+
expected_itemsize = numpy_support.as_dtype(typ.dtype).itemsize
63+
except NumbaNotImplementedError:
64+
# Don't check types that can't be `as_dtype()`-ed
65+
itemsize_mismatch = cgutils.false_bit
66+
else:
67+
expected_itemsize = nativeary.itemsize.type(expected_itemsize)
68+
itemsize_mismatch = c.builder.icmp_unsigned(
69+
"!=",
70+
nativeary.itemsize,
71+
expected_itemsize,
72+
)
73+
74+
failed = c.builder.or_(
75+
cgutils.is_not_null(c.builder, errcode),
76+
itemsize_mismatch,
77+
)
78+
# Handle error
79+
with c.builder.if_then(failed, likely=False):
80+
c.pyapi.err_set_string(
81+
"PyExc_TypeError",
82+
"can't unbox usm array from PyObject into "
83+
"native value. The object maybe of a "
84+
"different type",
85+
)
86+
return NativeValue(c.builder.load(aryptr), is_error=failed)
87+
88+
89+
@box(USMNdArray)
90+
def box_array(typ, val, c):
91+
"""Boxes a NativeValue representation of USMNdArray/DpnpNdArray type into a
92+
dpctl.tensor.usm_ndarray/dpnp.ndarray PyObject
93+
94+
Args:
95+
typ: The representation of the USMNdArray/DpnpNdArray type.
96+
val: A native representation of a Numba USMNdArray/DpnpNdArray type
97+
object.
98+
c: The boxing context.
99+
100+
Returns: A Pyobject for a dpctl.tensor.usm_ndarray/dpnp.ndarray boxed from
101+
the Numba-dpex native value.
102+
"""
103+
if c.context.enable_nrt:
104+
np_dtype = numpy_support.as_dtype(typ.dtype)
105+
dtypeptr = c.env_manager.read_const(c.env_manager.add_const(np_dtype))
106+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
107+
newary = dpexrtCtx.usm_ndarray_to_python_acqref(
108+
c.pyapi, typ, val, dtypeptr
109+
)
110+
111+
if not newary:
112+
c.pyapi.err_set_string(
113+
"PyExc_TypeError",
114+
"could not box native array into a dpnp.ndarray PyObject.",
115+
)
116+
117+
# Steals NRT ref
118+
# Refer:
119+
# numba.core.base.nrt -> numba.core.runtime.context -> decref
120+
# The `NRT_decref` function is generated directly as LLVM IR inside
121+
# numba.core.runtime.nrtdynmod.py
122+
c.context.nrt.decref(c.builder, typ, val)
123+
124+
return newary
125+
else:
126+
raise UnreachableError

numba_dpex/core/datamodel/models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ def __init__(self, dmm, fe_type):
196196

197197

198198
class RangeModel(StructModel):
199-
"""The native data model for a
200-
numba_dpex.core.kernel_interface.indexers.Range PyObject.
201-
"""
199+
"""The numba_dpex data model for a numba_dpex.kernel_api.Range PyObject."""
202200

203201
def __init__(self, dmm, fe_type):
204202
members = [
@@ -216,8 +214,8 @@ def flattened_field_count(self):
216214

217215

218216
class NdRangeModel(StructModel):
219-
"""The native data model for a
220-
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
217+
"""
218+
The numba_dpex data model for a numba_dpex.kernel_api.NdRange PyObject.
221219
"""
222220

223221
def __init__(self, dmm, fe_type):

0 commit comments

Comments
 (0)