Skip to content

Commit f99d3b0

Browse files
author
Diptorup Deb
committed
Refactoring the indexer classes.
- Splits up the numba_dpex.core.kernel_interface.indexers module and moves the Python classes into a new kernel_api.ranges submodule. - Moves the numba-dpex type definitions into core.types.kernel_api_types - Other changes as needed to get everything working
1 parent e7749a0 commit f99d3b0

File tree

17 files changed

+273
-241
lines changed

17 files changed

+273
-241
lines changed

numba_dpex/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,17 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
9999
# Re-export types itself
100100
import numba_dpex.core.types as types # noqa E402
101101
from numba_dpex.core import config # noqa E402
102-
from numba_dpex.core.kernel_interface.indexers import ( # noqa E402
103-
NdRange,
104-
Range,
105-
)
102+
from numba_dpex.core.kernel_interface import ranges_overloads # noqa E402
106103

107104
# Re-export all type names
108105
from numba_dpex.core.types import * # noqa E402
109106
from numba_dpex.dpctl_iface import _intrinsic # noqa E402
110107
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402
111108

109+
# Importing NdRange and Range into numba_dpex for
110+
# backward compatibility
111+
from numba_dpex.kernel_api import NdRange, Range # noqa E402
112+
112113
from .core.targets import dpjit_target, kernel_target # noqa E402
113114
from .decorators import dpjit, func, kernel # noqa E402
114115
from .ocl.stubs import ( # noqa E402

numba_dpex/core/types/range_types.py renamed to numba_dpex/core/boxing.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,55 +4,12 @@
44

55
from contextlib import ExitStack
66

7-
from numba.core import cgutils, errors, types
7+
from numba.core import cgutils, types
88
from numba.core.datamodel import default_manager
99
from numba.extending import NativeValue, box, unbox
1010

11-
from ..kernel_interface.indexers import NdRange, Range
12-
13-
14-
class RangeType(types.Type):
15-
"""Numba-dpex type corresponding to
16-
:class:`numba_dpex.core.kernel_interface.indexers.Range`
17-
"""
18-
19-
def __init__(self, ndim: int):
20-
self._ndim = ndim
21-
if ndim < 1 or ndim > 3:
22-
raise errors.TypingError(
23-
"RangeType can only have 1,2, or 3 dimensions"
24-
)
25-
super(RangeType, self).__init__(name="Range<" + str(ndim) + ">")
26-
27-
@property
28-
def ndim(self):
29-
return self._ndim
30-
31-
@property
32-
def key(self):
33-
return self._ndim
34-
35-
36-
class NdRangeType(types.Type):
37-
"""Numba-dpex type corresponding to
38-
:class:`numba_dpex.core.kernel_interface.indexers.NdRange`
39-
"""
40-
41-
def __init__(self, ndim: int):
42-
self._ndim = ndim
43-
if ndim < 1 or ndim > 3:
44-
raise errors.TypingError(
45-
"RangeType can only have 1,2, or 3 dimensions"
46-
)
47-
super(NdRangeType, self).__init__(name="NdRange<" + str(ndim) + ">")
48-
49-
@property
50-
def ndim(self):
51-
return self._ndim
52-
53-
@property
54-
def key(self):
55-
return self._ndim
11+
from numba_dpex.core.types import NdRangeType, RangeType
12+
from numba_dpex.kernel_api import NdRange, Range
5613

5714

5815
@unbox(RangeType)

numba_dpex/core/datamodel/models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ def __init__(self, dmm, fe_type):
196196

197197

198198
class RangeModel(StructModel):
199-
"""The native data model for a
200-
numba_dpex.core.kernel_interface.indexers.Range PyObject.
201-
"""
199+
"""The numba_dpex data model for a numba_dpex.kernel_api.Range PyObject."""
202200

203201
def __init__(self, dmm, fe_type):
204202
members = [
@@ -216,8 +214,8 @@ def flattened_field_count(self):
216214

217215

218216
class NdRangeModel(StructModel):
219-
"""The native data model for a
220-
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
217+
"""
218+
The numba_dpex data model for a numba_dpex.kernel_api.NdRange PyObject.
221219
"""
222220

223221
def __init__(self, dmm, fe_type):
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from llvmlite import ir as llvmir
6+
from numba.core import cgutils, errors, types
7+
from numba.core.datamodel import default_manager
8+
from numba.extending import intrinsic, overload
9+
10+
from numba_dpex.kernel_api import NdRange, Range
11+
12+
# can't import name because of the circular import
13+
DPEX_TARGET_NAME = "dpex"
14+
15+
16+
@intrinsic(target=DPEX_TARGET_NAME)
17+
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
18+
ty_retty = ty_range.instance_type
19+
sig = ty_retty(
20+
ty_dim0,
21+
ty_dim1,
22+
ty_dim2,
23+
ty_range,
24+
)
25+
26+
def codegen(context, builder, sig, args):
27+
typ = sig.return_type
28+
dim0, dim1, dim2, _ = args
29+
range_struct = cgutils.create_struct_proxy(typ)(context, builder)
30+
range_struct.dim0 = dim0
31+
32+
if not isinstance(sig.args[1], types.NoneType):
33+
range_struct.dim1 = dim1
34+
else:
35+
range_struct.dim1 = llvmir.Constant(
36+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
37+
)
38+
39+
if not isinstance(sig.args[2], types.NoneType):
40+
range_struct.dim2 = dim2
41+
else:
42+
range_struct.dim2 = llvmir.Constant(
43+
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
44+
)
45+
46+
range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)
47+
48+
return range_struct._getvalue()
49+
50+
return sig, codegen
51+
52+
53+
@intrinsic(target=DPEX_TARGET_NAME)
54+
def _intrin_ndrange_alloc(
55+
typingctx, ty_global_range, ty_local_range, ty_ndrange
56+
):
57+
ty_retty = ty_ndrange.instance_type
58+
sig = ty_retty(
59+
ty_global_range,
60+
ty_local_range,
61+
ty_ndrange,
62+
)
63+
range_datamodel = default_manager.lookup(ty_global_range)
64+
65+
def codegen(context, builder, sig, args):
66+
typ = sig.return_type
67+
68+
global_range, local_range, _ = args
69+
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
70+
ndrange_struct.ndim = llvmir.Constant(
71+
llvmir.types.IntType(64), typ.ndim
72+
)
73+
ndrange_struct.gdim0 = builder.extract_value(
74+
global_range,
75+
range_datamodel.get_field_position("dim0"),
76+
)
77+
ndrange_struct.gdim1 = builder.extract_value(
78+
global_range,
79+
range_datamodel.get_field_position("dim1"),
80+
)
81+
ndrange_struct.gdim2 = builder.extract_value(
82+
global_range,
83+
range_datamodel.get_field_position("dim2"),
84+
)
85+
ndrange_struct.ldim0 = builder.extract_value(
86+
local_range,
87+
range_datamodel.get_field_position("dim0"),
88+
)
89+
ndrange_struct.ldim1 = builder.extract_value(
90+
local_range,
91+
range_datamodel.get_field_position("dim1"),
92+
)
93+
ndrange_struct.ldim2 = builder.extract_value(
94+
local_range,
95+
range_datamodel.get_field_position("dim2"),
96+
)
97+
98+
return ndrange_struct._getvalue()
99+
100+
return sig, codegen
101+
102+
103+
@overload(Range, target=DPEX_TARGET_NAME)
104+
def _ol_range_init(dim0, dim1=None, dim2=None):
105+
"""Numba overload of the Range constructor to make it usable inside an
106+
njit and dpjit decorated function.
107+
108+
"""
109+
from numba_dpex.core.types import RangeType
110+
111+
ndims = 1
112+
ty_optional_dims = (dim1, dim2)
113+
114+
# A Range should at least have the 0th dimension populated
115+
if not isinstance(dim0, types.Integer):
116+
raise errors.TypingError(
117+
"Expected a Range's dimension should to be an Integer value, but "
118+
"encountered " + dim0.name
119+
)
120+
121+
for ty_dim in ty_optional_dims:
122+
if isinstance(ty_dim, types.Integer):
123+
ndims += 1
124+
elif ty_dim is not None:
125+
raise errors.TypingError(
126+
"Expected a Range's dimension to be an Integer value, "
127+
f"but {type(ty_dim)} was provided."
128+
)
129+
130+
ret_ty = RangeType(ndims)
131+
132+
def impl(dim0, dim1=None, dim2=None):
133+
return _intrin_range_alloc(dim0, dim1, dim2, ret_ty)
134+
135+
return impl
136+
137+
138+
@overload(NdRange, target=DPEX_TARGET_NAME)
139+
def _ol_ndrange_init(global_range, local_range):
140+
"""Numba overload of the NdRange constructor to make it usable inside an
141+
njit and dpjit decorated function.
142+
143+
"""
144+
from numba_dpex.core.exceptions import UnmatchedNumberOfRangeDimsError
145+
from numba_dpex.core.types import NdRangeType, RangeType
146+
147+
if not isinstance(global_range, RangeType):
148+
raise errors.TypingError(
149+
"Only global range values specified as a Range are "
150+
"supported inside dpjit"
151+
)
152+
153+
if not isinstance(local_range, RangeType):
154+
raise errors.TypingError(
155+
"Only local range values specified as a Range are "
156+
"supported inside dpjit"
157+
)
158+
159+
if not global_range.ndim == local_range.ndim:
160+
raise UnmatchedNumberOfRangeDimsError(
161+
kernel_name="",
162+
global_ndims=global_range.ndim,
163+
local_ndims=local_range.ndim,
164+
)
165+
166+
ret_ty = NdRangeType(global_range.ndim)
167+
168+
def impl(global_range, local_range):
169+
return _intrin_ndrange_alloc(global_range, local_range, ret_ty)
170+
171+
return impl

numba_dpex/core/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .array_type import Array
66
from .dpctl_types import DpctlSyclEvent, DpctlSyclQueue
77
from .dpnp_ndarray_type import DpnpNdArray
8+
from .kernel_api_types.range_types import NdRangeType, RangeType
89
from .numba_types_short_names import (
910
b1,
1011
bool_,
@@ -26,7 +27,6 @@
2627
uint64,
2728
void,
2829
)
29-
from .range_types import NdRangeType, RangeType
3030
from .usm_ndarray_type import USMNdArray
3131

3232
usm_ndarray = USMNdArray
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from contextlib import ExitStack
6+
7+
from numba.core import cgutils, errors, types
8+
9+
10+
class RangeType(types.Type):
11+
"""Numba-dpex type corresponding to
12+
:class:`numba_dpex.kernel_api.ranges.Range`
13+
"""
14+
15+
def __init__(self, ndim: int):
16+
self._ndim = ndim
17+
if ndim < 1 or ndim > 3:
18+
raise errors.TypingError(
19+
"RangeType can only have 1,2, or 3 dimensions"
20+
)
21+
super(RangeType, self).__init__(name="Range<" + str(ndim) + ">")
22+
23+
@property
24+
def ndim(self):
25+
return self._ndim
26+
27+
@property
28+
def key(self):
29+
return self._ndim
30+
31+
32+
class NdRangeType(types.Type):
33+
"""Numba-dpex type corresponding to
34+
:class:`numba_dpex.kernel_api.ranges.NdRange`
35+
"""
36+
37+
def __init__(self, ndim: int):
38+
self._ndim = ndim
39+
if ndim < 1 or ndim > 3:
40+
raise errors.TypingError(
41+
"RangeType can only have 1,2, or 3 dimensions"
42+
)
43+
super(NdRangeType, self).__init__(name="NdRange<" + str(ndim) + ">")
44+
45+
@property
46+
def ndim(self):
47+
return self._ndim
48+
49+
@property
50+
def key(self):
51+
return self._ndim

numba_dpex/core/typing/typeof.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from numba.extending import typeof_impl
99
from numba.np import numpy_support
1010

11+
from numba_dpex.kernel_api.ranges import NdRange, Range
1112
from numba_dpex.utils.constants import address_space
1213

13-
from ..kernel_interface.indexers import NdRange, Range
1414
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
1515
from ..types.dpnp_ndarray_type import DpnpNdArray
16-
from ..types.range_types import NdRangeType, RangeType
16+
from ..types.kernel_api_types.range_types import NdRangeType, RangeType
1717
from ..types.usm_ndarray_type import USMNdArray
1818

1919

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from numba_dpex.core.exceptions import UnreachableError
2222
from numba_dpex.core.runtime.context import DpexRTContext
2323
from numba_dpex.core.types import USMNdArray
24-
from numba_dpex.core.types.range_types import NdRangeType, RangeType
24+
from numba_dpex.core.types.kernel_api_types.range_types import (
25+
NdRangeType,
26+
RangeType,
27+
)
2528
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2629
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
2730
from numba_dpex.utils import create_null_ptr

numba_dpex/experimental/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
from numba.core.imputils import Registry
1010

11+
# Temporary so that Range and NdRange work in experimental call_kernel
12+
from numba_dpex.core import boxing
13+
1114
from ._kernel_dpcpp_spirv_overloads import _atomic_ref_overloads
1215
from .decorators import device_func, kernel
1316
from .kernel_dispatcher import KernelDispatcher

0 commit comments

Comments
 (0)