Skip to content

Commit f0526eb

Browse files
author
Diptorup Deb
committed
Adds Range and NdRange as supported types in numba_dpex.dpjit.
- The numba_dpex.core.kernel_interface.indexers.Range and NdRange classes can now be used inside a numba_dpex.dpjit decorated function, as arguments to a dpjit decorated function, and returned from a dpjit decorated function. - Defines a datamodel, type class, overloads for the constructors, box, and unbox functions for both classes.
1 parent a39cdb5 commit f0526eb

File tree

5 files changed

+545
-0
lines changed

5 files changed

+545
-0
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
DpctlSyclEvent,
1515
DpctlSyclQueue,
1616
DpnpNdArray,
17+
NdRangeType,
18+
RangeType,
1719
USMNdArray,
1820
)
1921

@@ -195,6 +197,39 @@ def __init__(self, dmm, fe_type):
195197
super(SyclEventModel, self).__init__(dmm, fe_type, members)
196198

197199

200+
class RangeModel(StructModel):
201+
"""The native data model for a
202+
numba_dpex.core.kernel_interface.indexers.Range PyObject.
203+
"""
204+
205+
def __init__(self, dmm, fe_type):
206+
members = [
207+
("ndim", types.int64),
208+
("dim0", types.int64),
209+
("dim1", types.int64),
210+
("dim2", types.int64),
211+
]
212+
super(RangeModel, self).__init__(dmm, fe_type, members)
213+
214+
215+
class NdRangeModel(StructModel):
216+
"""The native data model for a
217+
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
218+
"""
219+
220+
def __init__(self, dmm, fe_type):
221+
members = [
222+
("ndim", types.int64),
223+
("gdim0", types.int64),
224+
("gdim1", types.int64),
225+
("gdim2", types.int64),
226+
("ldim0", types.int64),
227+
("ldim1", types.int64),
228+
("ldim2", types.int64),
229+
]
230+
super(NdRangeModel, self).__init__(dmm, fe_type, members)
231+
232+
198233
def _init_data_model_manager() -> datamodel.DataModelManager:
199234
"""Initializes a DpexKernelTarget-specific data model manager.
200235
@@ -249,3 +284,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
249284

250285
# Register the DpctlSyclEvent type
251286
register_model(DpctlSyclEvent)(SyclEventModel)
287+
# Register the RangeType type
288+
register_model(RangeType)(RangeModel)
289+
290+
# Register the NdRangeType type
291+
register_model(NdRangeType)(NdRangeModel)

numba_dpex/core/kernel_interface/indexers.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
from collections.abc import Iterable
66

7+
from llvmlite import ir as llvmir
8+
from numba.core import cgutils, errors, types
9+
from numba.core.datamodel import default_manager
10+
from numba.extending import intrinsic, overload
11+
712

813
class Range(tuple):
914
"""A data structure to encapsulate a single kernel launch parameter.
@@ -74,6 +79,50 @@ def size(self):
7479
else:
7580
return self[0]
7681

82+
@property
83+
def ndim(self) -> int:
84+
"""Returns the rank of a Range object.
85+
86+
Returns:
87+
int: Number of dimensions in the Range object
88+
"""
89+
return len(self)
90+
91+
@property
92+
def dim0(self) -> int:
93+
"""Return the extent of the first dimension for the Range object.
94+
95+
Returns:
96+
int: Extent of first dimension for the Range object
97+
"""
98+
return self[0]
99+
100+
@property
101+
def dim1(self) -> int:
102+
"""Return the extent of the second dimension for the Range object.
103+
104+
Returns:
105+
int: Extent of second dimension for the Range object or -1 for 1D
106+
Range
107+
"""
108+
try:
109+
return self[1]
110+
except:
111+
return -1
112+
113+
@property
114+
def dim2(self) -> int:
115+
"""Return the extent of the second dimension for the Range object.
116+
117+
Returns:
118+
int: Extent of second dimension for the Range object or -1 for 1D or
119+
2D Range
120+
"""
121+
try:
122+
return self[2]
123+
except:
124+
return -1
125+
77126

78127
class NdRange:
79128
"""A class to encapsulate all kernel launch parameters.
@@ -169,3 +218,157 @@ def __repr__(self):
169218
str: str representation for NdRange class.
170219
"""
171220
return self.__str__()
221+
222+
223+
@intrinsic
224+
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
225+
ty_retty = ty_range.instance_type
226+
sig = ty_retty(
227+
ty_dim0,
228+
ty_dim1,
229+
ty_dim2,
230+
ty_range,
231+
)
232+
233+
def codegen(context, builder, sig, args):
234+
typ = sig.return_type
235+
dim0, dim1, dim2, _ = args
236+
range_struct = cgutils.create_struct_proxy(typ)(context, builder)
237+
range_struct.dim0 = dim0
238+
239+
if not isinstance(sig.args[1], types.NoneType):
240+
range_struct.dim1 = dim1
241+
else:
242+
range_struct.dim1 = llvmir.Constant(llvmir.types.IntType(64), -1)
243+
244+
if not isinstance(sig.args[2], types.NoneType):
245+
range_struct.dim2 = dim2
246+
else:
247+
range_struct.dim2 = llvmir.Constant(llvmir.types.IntType(64), -1)
248+
249+
range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)
250+
251+
return range_struct._getvalue()
252+
253+
return sig, codegen
254+
255+
256+
@intrinsic
257+
def _intrin_ndrange_alloc(
258+
typingctx, ty_global_range, ty_local_range, ty_ndrange
259+
):
260+
ty_retty = ty_ndrange.instance_type
261+
sig = ty_retty(
262+
ty_global_range,
263+
ty_local_range,
264+
ty_ndrange,
265+
)
266+
range_datamodel = default_manager.lookup(ty_global_range)
267+
268+
def codegen(context, builder, sig, args):
269+
typ = sig.return_type
270+
271+
global_range, local_range, _ = args
272+
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
273+
ndrange_struct.ndim = llvmir.Constant(
274+
llvmir.types.IntType(64), typ.ndim
275+
)
276+
ndrange_struct.gdim0 = builder.extract_value(
277+
global_range,
278+
range_datamodel.get_field_position("dim0"),
279+
)
280+
ndrange_struct.gdim1 = builder.extract_value(
281+
global_range,
282+
range_datamodel.get_field_position("dim1"),
283+
)
284+
ndrange_struct.gdim2 = builder.extract_value(
285+
global_range,
286+
range_datamodel.get_field_position("dim2"),
287+
)
288+
ndrange_struct.ldim0 = builder.extract_value(
289+
local_range,
290+
range_datamodel.get_field_position("dim0"),
291+
)
292+
ndrange_struct.ldim1 = builder.extract_value(
293+
local_range,
294+
range_datamodel.get_field_position("dim1"),
295+
)
296+
ndrange_struct.ldim2 = builder.extract_value(
297+
local_range,
298+
range_datamodel.get_field_position("dim2"),
299+
)
300+
301+
return ndrange_struct._getvalue()
302+
303+
return sig, codegen
304+
305+
306+
@overload(Range)
307+
def _ol_range_init(dim0, dim1=None, dim2=None):
308+
"""Numba overload of the Range constructor to make it usable inside an
309+
njit and dpjit decorated function.
310+
311+
"""
312+
from numba_dpex.core.types import RangeType
313+
314+
ndims = 1
315+
ty_optional_dims = (dim1, dim2)
316+
317+
# A Range should at least have the 0th dimension populated
318+
if not isinstance(dim0, types.Integer):
319+
raise errors.TypingError(
320+
"Expected a Range's dimension should to be an Integer value, but "
321+
"encountered " + dim0.name
322+
)
323+
324+
for ty_dim in ty_optional_dims:
325+
if isinstance(ty_dim, types.Integer):
326+
ndims += 1
327+
elif ty_dim is not None:
328+
raise errors.TypingError(
329+
"Expected a Range's dimension to be an Integer value, "
330+
f"but {type(ty_dim)} was provided."
331+
)
332+
333+
ret_ty = RangeType(ndims)
334+
335+
def impl(dim0, dim1=None, dim2=None):
336+
return _intrin_range_alloc(dim0, dim1, dim2, ret_ty)
337+
338+
return impl
339+
340+
341+
@overload(NdRange)
342+
def _ol_ndrange_init(global_range, local_range):
343+
"""Numba overload of the NdRange constructor to make it usable inside an
344+
njit and dpjit decorated function.
345+
346+
"""
347+
from numba_dpex.core.exceptions import UnmatchedNumberOfRangeDimsError
348+
from numba_dpex.core.types import NdRangeType, RangeType
349+
350+
if not isinstance(global_range, RangeType):
351+
raise errors.TypingError(
352+
"Only global range values specified as a Range are "
353+
"supported inside dpjit"
354+
)
355+
356+
if not isinstance(local_range, RangeType):
357+
raise errors.TypingError(
358+
"Only local range values specified as a Range are "
359+
"supported inside dpjit"
360+
)
361+
362+
if not global_range.ndim == local_range.ndim:
363+
raise UnmatchedNumberOfRangeDimsError(
364+
kernel_name="",
365+
global_ndims=global_range.ndim,
366+
local_ndims=local_range.ndim,
367+
)
368+
369+
ret_ty = NdRangeType(global_range.ndim)
370+
371+
def impl(global_range, local_range):
372+
return _intrin_ndrange_alloc(global_range, local_range, ret_ty)
373+
374+
return impl

numba_dpex/core/types/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
uint64,
2727
void,
2828
)
29+
from .range_types import NdRangeType, RangeType
2930
from .usm_ndarray_type import USMNdArray
3031

3132
usm_ndarray = USMNdArray
@@ -35,6 +36,8 @@
3536
"DpctlSyclQueue",
3637
"DpctlSyclEvent",
3738
"DpnpNdArray",
39+
"RangeType",
40+
"NdRangeType",
3841
"USMNdArray",
3942
"none",
4043
"boolean",

0 commit comments

Comments
 (0)