Skip to content

Commit a3c8646

Browse files
author
Diptorup Deb
committed
New data model for DpnpNdArray type objects.
- Creates a new data model DpnpNdArrayModel to represent DpnpNdArray type objects natively. The data model differs from numba's ArrayModel by having an extra member to store a sycl::queue pointer. - Introduces a _usmarraystruct.h header to define the C struct for the DpnpNdArrayModel. - Renames numba_dpex.core.datamodel.models.ArrayModel to USMArrayModel. - Updates kernel launcher and parfor lowering functions to account for the new data model.
1 parent cd81ca3 commit a3c8646

File tree

8 files changed

+167
-64
lines changed

8 files changed

+167
-64
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from numba.core import datamodel, types
6-
from numba.core.datamodel.models import ArrayModel as DpnpNdArrayModel
76
from numba.core.datamodel.models import PrimitiveModel, StructModel
87
from numba.core.extending import register_model
98

9+
from numba_dpex.core.exceptions import UnreachableError
1010
from numba_dpex.utils import address_space
1111

1212
from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
@@ -23,7 +23,7 @@ def __init__(self, dmm, fe_type):
2323
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)
2424

2525

26-
class ArrayModel(StructModel):
26+
class USMArrayModel(StructModel):
2727
"""A data model to represent a Dpex's array types in LLVM IR.
2828
2929
Dpex's ArrayModel is based on Numba's ArrayModel for NumPy arrays. The
@@ -40,18 +40,69 @@ def __init__(self, dmm, fe_type):
4040
),
4141
(
4242
"parent",
43-
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
43+
types.CPointer(types.pyobject, addrspace=fe_type.addrspace),
4444
),
4545
("nitems", types.intp),
4646
("itemsize", types.intp),
4747
(
4848
"data",
4949
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
5050
),
51+
("sycl_queue", types.voidptr),
52+
("shape", types.UniTuple(types.intp, ndim)),
53+
("strides", types.UniTuple(types.intp, ndim)),
54+
]
55+
super(USMArrayModel, self).__init__(dmm, fe_type, members)
56+
57+
58+
class DpnpNdArrayModel(StructModel):
59+
"""Data model for the DpnpNdArray type.
60+
61+
The data model for DpnpNdArray is similar to numb's ArrayModel used for
62+
the numba.types.Array type, with the additional field ``sycl_queue`. The
63+
`sycl_queue` attribute stores the pointer to the C++ sycl::queue object
64+
that was used to allocate memory for numba-dpex's native representation
65+
for an Python object inferred as a DpnpNdArray.
66+
"""
67+
68+
def __init__(self, dmm, fe_type):
69+
ndim = fe_type.ndim
70+
members = [
71+
("meminfo", types.MemInfoPointer(fe_type.dtype)),
72+
("parent", types.pyobject),
73+
("nitems", types.intp),
74+
("itemsize", types.intp),
75+
("data", types.CPointer(fe_type.dtype)),
76+
("sycl_queue", types.voidptr),
5177
("shape", types.UniTuple(types.intp, ndim)),
5278
("strides", types.UniTuple(types.intp, ndim)),
5379
]
54-
super(ArrayModel, self).__init__(dmm, fe_type, members)
80+
super(DpnpNdArrayModel, self).__init__(dmm, fe_type, members)
81+
82+
@property
83+
def flattened_field_count(self):
84+
"""Return the number of fields in an instance of a DpnpNdArrayModel."""
85+
flattened_member_count = 0
86+
members = self._members
87+
for member in members:
88+
if isinstance(member, types.UniTuple):
89+
flattened_member_count += member.count
90+
elif isinstance(
91+
member,
92+
(
93+
types.scalars.Integer,
94+
types.misc.PyObject,
95+
types.misc.RawPointer,
96+
types.misc.CPointer,
97+
types.misc.MemInfoPointer,
98+
),
99+
):
100+
flattened_member_count += 1
101+
else:
102+
print(member, type(member))
103+
raise UnreachableError
104+
105+
return flattened_member_count
55106

56107

57108
class SyclQueueModel(StructModel):
@@ -84,7 +135,7 @@ def __init__(self, dmm, fe_type):
84135
def _init_data_model_manager():
85136
dmm = datamodel.default_manager.copy()
86137
dmm.register(types.CPointer, GenericPointerModel)
87-
dmm.register(Array, ArrayModel)
138+
dmm.register(Array, USMArrayModel)
88139
return dmm
89140

90141

@@ -103,8 +154,8 @@ def _init_data_model_manager():
103154
# object.
104155

105156
# Register the USMNdArray type with the dpex ArrayModel
106-
register_model(USMNdArray)(ArrayModel)
107-
dpex_data_model_manager.register(USMNdArray, ArrayModel)
157+
register_model(USMNdArray)(USMArrayModel)
158+
dpex_data_model_manager.register(USMNdArray, USMArrayModel)
108159

109160
# Register the DpnpNdArray type with the Numba ArrayModel
110161
register_model(DpnpNdArray)(DpnpNdArrayModel)

numba_dpex/core/kernel_interface/arg_pack_unpacker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def _unpack_usm_array(self, val):
5454
unpacked_array_attrs.append(ctypes.c_longlong(size))
5555
unpacked_array_attrs.append(ctypes.c_longlong(itemsize))
5656
unpacked_array_attrs.append(buf)
57+
# queue: unused and passed as void*
58+
unpacked_array_attrs.append(ctypes.c_size_t(0))
5759
for ax in range(ndim):
5860
unpacked_array_attrs.append(ctypes.c_longlong(shape[ax]))
5961
for ax in range(ndim):

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import copy
66

77
from llvmlite import ir as llvmir
8-
from numba.core import cgutils, ir, types
8+
from numba.core import ir, types
99
from numba.parfors.parfor import (
1010
find_potential_aliases_parfor,
1111
get_parfor_outputs,
@@ -26,6 +26,8 @@
2626
create_reduction_remainder_kernel_for_parfor,
2727
)
2828

29+
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
30+
2931
# A global list of kernels to keep the objects alive indefinitely.
3032
keep_alive_kernels = []
3133

@@ -114,8 +116,8 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
114116
# kernel_fn.kernel_args as arrays get flattened.
115117
for arg_type in kernel_fn.kernel_arg_types:
116118
if isinstance(arg_type, DpnpNdArray):
117-
# FIXME: Remove magic constants
118-
num_flattened_args += 5 + (2 * arg_type.ndim)
119+
datamodel = dpex_dmm.lookup(arg_type)
120+
num_flattened_args += datamodel.flattened_field_count
119121
elif arg_type == types.complex64 or arg_type == types.complex128:
120122
num_flattened_args += 2
121123
else:
@@ -134,15 +136,16 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
134136
argtype = kernel_fn.kernel_arg_types[arg_num]
135137
llvm_val = _getvar(lowerer, arg)
136138
if isinstance(argtype, DpnpNdArray):
139+
datamodel = dpex_dmm.lookup(arg_type)
137140
self.kernel_builder.build_array_arg(
138141
array_val=llvm_val,
142+
array_data_model=datamodel,
139143
array_rank=argtype.ndim,
140144
arg_list=self.args_list,
141145
args_ty_list=self.args_ty_list,
142146
arg_num=self.kernel_arg_num,
143147
)
144-
# FIXME: Get rid of magic constants
145-
self.kernel_arg_num += 5 + (2 * argtype.ndim)
148+
self.kernel_arg_num += datamodel.flattened_field_count
146149
else:
147150
if argtype == types.complex64:
148151
self.kernel_builder.build_complex_arg(

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
#include "_dbg_printer.h"
2323
#include "_queuestruct.h"
24-
#include "numba/_arraystruct.h"
24+
#include "_usmarraystruct.h"
2525

2626
// forward declarations
2727
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
@@ -49,14 +49,14 @@ static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
4949
size_t usm_type,
5050
const DPCTLSyclQueueRef qref);
5151
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
52-
static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
52+
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
5353
int ndim,
5454
PyArray_Descr *descr);
5555

5656
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
57-
arystruct_t *arystruct);
57+
usmarystruct_t *arystruct);
5858
static PyObject *
59-
DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
59+
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
6060
PyTypeObject *retty,
6161
int ndim,
6262
int writeable,
@@ -770,7 +770,7 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
770770
* @return {return} Error code representing success (0) or failure (-1).
771771
*/
772772
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
773-
arystruct_t *arystruct)
773+
usmarystruct_t *arystruct)
774774
{
775775
struct PyUSMArrayObject *arrayobj = NULL;
776776
int i = 0, j = 0, k = 0, ndim = 0, exp = 0;
@@ -827,6 +827,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
827827
}
828828

829829
arystruct->data = data;
830+
arystruct->sycl_queue = qref;
830831
arystruct->nitems = nitems;
831832
arystruct->itemsize = itemsize;
832833
arystruct->parent = obj;
@@ -892,7 +893,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
892893
* @return {return} A PyObject created from the arystruct_t->parent, if
893894
* the PyObject could not be created return NULL.
894895
*/
895-
static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
896+
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
896897
int ndim,
897898
PyArray_Descr *descr)
898899
{
@@ -914,8 +915,10 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
914915
}
915916

916917
if ((void *)UsmNDArray_GetData(arrayobj) != arystruct->data) {
917-
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed "
918-
"from parent as data pointer is NULL.\n"));
918+
DPEXRT_DEBUG(drt_debug_print(
919+
"DPEXRT-DEBUG: Arrayobj cannot be boxed "
920+
"from parent as data pointer in the arystruct is not the same as "
921+
"the data pointer in the parent object.\n"));
919922
return NULL;
920923
}
921924

@@ -978,7 +981,7 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
978981
*
979982
*/
980983
static PyObject *
981-
DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
984+
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
982985
PyTypeObject *retty,
983986
int ndim,
984987
int writeable,
@@ -1094,8 +1097,7 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
10941097
typenum = descr->type_num;
10951098
usm_ndarr_obj = UsmNDArray_MakeFromPtr(
10961099
ndim, shape, typenum, strides, (DPCTLSyclUSMRef)arystruct->data,
1097-
(DPCTLSyclQueueRef)miobj->meminfo->external_allocator->opaque_data, 0,
1098-
(PyObject *)miobj);
1100+
(DPCTLSyclQueueRef)arystruct->sycl_queue, 0, (PyObject *)miobj);
10991101

11001102
if (usm_ndarr_obj == NULL ||
11011103
!PyObject_TypeCheck(usm_ndarr_obj, &PyUSMArrayType))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// Defines the numba-dpex native representation for a dpctl.tensor.usm_ndarray
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#pragma once
13+
14+
#include <Python.h>
15+
#include <numpy/npy_common.h>
16+
17+
typedef struct
18+
{
19+
void *meminfo;
20+
PyObject *parent;
21+
npy_intp nitems;
22+
npy_intp itemsize;
23+
void *data;
24+
void *sycl_queue;
25+
26+
npy_intp shape_and_strides[];
27+
} usmarystruct_t;

0 commit comments

Comments
 (0)