Skip to content

Commit 81af300

Browse files
author
Diptorup Deb
authored
Merge pull request #1083 from IntelPython/feature/improved_sycl_queue_support
Feature/improved sycl queue support
2 parents cd81ca3 + 4877352 commit 81af300

File tree

18 files changed

+542
-175
lines changed

18 files changed

+542
-175
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from numba.core import datamodel, types
6-
from numba.core.datamodel.models import ArrayModel as DpnpNdArrayModel
76
from numba.core.datamodel.models import PrimitiveModel, StructModel
87
from numba.core.extending import register_model
98

9+
from numba_dpex.core.exceptions import UnreachableError
1010
from numba_dpex.utils import address_space
1111

1212
from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
@@ -23,7 +23,7 @@ def __init__(self, dmm, fe_type):
2323
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)
2424

2525

26-
class ArrayModel(StructModel):
26+
class USMArrayModel(StructModel):
2727
"""A data model to represent Dpex's array types in LLVM IR.
2828
2929
Dpex's USMArrayModel is based on Numba's ArrayModel for NumPy arrays. The
@@ -40,18 +40,75 @@ def __init__(self, dmm, fe_type):
4040
),
4141
(
4242
"parent",
43-
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
43+
types.CPointer(types.pyobject, addrspace=fe_type.addrspace),
4444
),
4545
("nitems", types.intp),
4646
("itemsize", types.intp),
4747
(
4848
"data",
4949
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
5050
),
51+
(
52+
"sycl_queue",
53+
types.CPointer(types.void, addrspace=fe_type.addrspace),
54+
),
55+
("shape", types.UniTuple(types.intp, ndim)),
56+
("strides", types.UniTuple(types.intp, ndim)),
57+
]
58+
super(USMArrayModel, self).__init__(dmm, fe_type, members)
59+
60+
61+
class DpnpNdArrayModel(StructModel):
62+
"""Data model for the DpnpNdArray type.
63+
64+
DpnpNdArrayModel is used by the numba_dpex.types.DpnpNdArray type and
65+
abstracts the usmarystruct_t C type defined in
66+
numba_dpex.core.runtime._usmarraystruct.h.
67+
68+
The DpnpNdArrayModel differs from numba's ArrayModel by including an extra
69+
member sycl_queue that maps to _usmarraystruct.sycl_queue pointer. The
70+
_usmarraystruct.sycl_queue pointer stores the C++ sycl::queue pointer that
71+
was used to allocate the data for the dpnp.ndarray represented by an
72+
instance of _usmarraystruct.
73+
"""
74+
75+
def __init__(self, dmm, fe_type):
76+
ndim = fe_type.ndim
77+
members = [
78+
("meminfo", types.MemInfoPointer(fe_type.dtype)),
79+
("parent", types.pyobject),
80+
("nitems", types.intp),
81+
("itemsize", types.intp),
82+
("data", types.CPointer(fe_type.dtype)),
83+
("sycl_queue", types.voidptr),
5184
("shape", types.UniTuple(types.intp, ndim)),
5285
("strides", types.UniTuple(types.intp, ndim)),
5386
]
54-
super(ArrayModel, self).__init__(dmm, fe_type, members)
87+
super(DpnpNdArrayModel, self).__init__(dmm, fe_type, members)
88+
89+
@property
90+
def flattened_field_count(self):
91+
"""Return the number of fields in an instance of a DpnpNdArrayModel."""
92+
flattened_member_count = 0
93+
members = self._members
94+
for member in members:
95+
if isinstance(member, types.UniTuple):
96+
flattened_member_count += member.count
97+
elif isinstance(
98+
member,
99+
(
100+
types.scalars.Integer,
101+
types.misc.PyObject,
102+
types.misc.RawPointer,
103+
types.misc.CPointer,
104+
types.misc.MemInfoPointer,
105+
),
106+
):
107+
flattened_member_count += 1
108+
else:
109+
raise UnreachableError
110+
111+
return flattened_member_count
55112

56113

57114
class SyclQueueModel(StructModel):
@@ -84,7 +141,7 @@ def __init__(self, dmm, fe_type):
84141
def _init_data_model_manager():
85142
dmm = datamodel.default_manager.copy()
86143
dmm.register(types.CPointer, GenericPointerModel)
87-
dmm.register(Array, ArrayModel)
144+
dmm.register(Array, USMArrayModel)
88145
return dmm
89146

90147

@@ -103,8 +160,8 @@ def _init_data_model_manager():
103160
# object.
104161

105162
# Register the USMNdArray type with the dpex USMArrayModel
106-
register_model(USMNdArray)(ArrayModel)
107-
dpex_data_model_manager.register(USMNdArray, ArrayModel)
163+
register_model(USMNdArray)(USMArrayModel)
164+
dpex_data_model_manager.register(USMNdArray, USMArrayModel)
108165

109166
# Register the DpnpNdArray type with the dpex DpnpNdArrayModel
110167
register_model(DpnpNdArray)(DpnpNdArrayModel)

numba_dpex/core/kernel_interface/arg_pack_unpacker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def _unpack_usm_array(self, val):
5454
unpacked_array_attrs.append(ctypes.c_longlong(size))
5555
unpacked_array_attrs.append(ctypes.c_longlong(itemsize))
5656
unpacked_array_attrs.append(buf)
57+
# queue: unused and passed as void*
58+
unpacked_array_attrs.append(ctypes.c_size_t(0))
5759
for ax in range(ndim):
5860
unpacked_array_attrs.append(ctypes.c_longlong(shape[ax]))
5961
for ax in range(ndim):

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import copy
66

77
from llvmlite import ir as llvmir
8-
from numba.core import cgutils, ir, types
8+
from numba.core import ir, types
99
from numba.parfors.parfor import (
1010
find_potential_aliases_parfor,
1111
get_parfor_outputs,
@@ -26,6 +26,8 @@
2626
create_reduction_remainder_kernel_for_parfor,
2727
)
2828

29+
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
30+
2931
# A global list of kernels to keep the objects alive indefinitely.
3032
keep_alive_kernels = []
3133

@@ -114,8 +116,8 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
114116
# kernel_fn.kernel_args as arrays get flattened.
115117
for arg_type in kernel_fn.kernel_arg_types:
116118
if isinstance(arg_type, DpnpNdArray):
117-
# FIXME: Remove magic constants
118-
num_flattened_args += 5 + (2 * arg_type.ndim)
119+
datamodel = dpex_dmm.lookup(arg_type)
120+
num_flattened_args += datamodel.flattened_field_count
119121
elif arg_type == types.complex64 or arg_type == types.complex128:
120122
num_flattened_args += 2
121123
else:
@@ -134,15 +136,16 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
134136
argtype = kernel_fn.kernel_arg_types[arg_num]
135137
llvm_val = _getvar(lowerer, arg)
136138
if isinstance(argtype, DpnpNdArray):
139+
datamodel = dpex_dmm.lookup(arg_type)
137140
self.kernel_builder.build_array_arg(
138141
array_val=llvm_val,
142+
array_data_model=datamodel,
139143
array_rank=argtype.ndim,
140144
arg_list=self.args_list,
141145
args_ty_list=self.args_ty_list,
142146
arg_num=self.kernel_arg_num,
143147
)
144-
# FIXME: Get rid of magic constants
145-
self.kernel_arg_num += 5 + (2 * argtype.ndim)
148+
self.kernel_arg_num += datamodel.flattened_field_count
146149
else:
147150
if argtype == types.complex64:
148151
self.kernel_builder.build_complex_arg(

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
#include "_dbg_printer.h"
2323
#include "_queuestruct.h"
24-
#include "numba/_arraystruct.h"
24+
#include "_usmarraystruct.h"
2525

2626
// forward declarations
2727
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
@@ -49,14 +49,14 @@ static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
4949
size_t usm_type,
5050
const DPCTLSyclQueueRef qref);
5151
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
52-
static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
52+
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
5353
int ndim,
5454
PyArray_Descr *descr);
5555

5656
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
57-
arystruct_t *arystruct);
57+
usmarystruct_t *arystruct);
5858
static PyObject *
59-
DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
59+
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
6060
PyTypeObject *retty,
6161
int ndim,
6262
int writeable,
@@ -770,7 +770,7 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
770770
* @return {return} Error code representing success (0) or failure (-1).
771771
*/
772772
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
773-
arystruct_t *arystruct)
773+
usmarystruct_t *arystruct)
774774
{
775775
struct PyUSMArrayObject *arrayobj = NULL;
776776
int i = 0, j = 0, k = 0, ndim = 0, exp = 0;
@@ -827,6 +827,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
827827
}
828828

829829
arystruct->data = data;
830+
arystruct->sycl_queue = qref;
830831
arystruct->nitems = nitems;
831832
arystruct->itemsize = itemsize;
832833
arystruct->parent = obj;
@@ -892,7 +893,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
892893
* @return {return} A PyObject created from the usmarystruct_t->parent, if
893894
* the PyObject could not be created return NULL.
894895
*/
895-
static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
896+
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
896897
int ndim,
897898
PyArray_Descr *descr)
898899
{
@@ -914,8 +915,10 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
914915
}
915916

916917
if ((void *)UsmNDArray_GetData(arrayobj) != arystruct->data) {
917-
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed "
918-
"from parent as data pointer is NULL.\n"));
918+
DPEXRT_DEBUG(drt_debug_print(
919+
"DPEXRT-DEBUG: Arrayobj cannot be boxed "
920+
"from parent as data pointer in the arystruct is not the same as "
921+
"the data pointer in the parent object.\n"));
919922
return NULL;
920923
}
921924

@@ -978,7 +981,7 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
978981
*
979982
*/
980983
static PyObject *
981-
DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
984+
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
982985
PyTypeObject *retty,
983986
int ndim,
984987
int writeable,
@@ -1094,8 +1097,7 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
10941097
typenum = descr->type_num;
10951098
usm_ndarr_obj = UsmNDArray_MakeFromPtr(
10961099
ndim, shape, typenum, strides, (DPCTLSyclUSMRef)arystruct->data,
1097-
(DPCTLSyclQueueRef)miobj->meminfo->external_allocator->opaque_data, 0,
1098-
(PyObject *)miobj);
1100+
(DPCTLSyclQueueRef)arystruct->sycl_queue, 0, (PyObject *)miobj);
10991101

11001102
if (usm_ndarr_obj == NULL ||
11011103
!PyObject_TypeCheck(usm_ndarr_obj, &PyUSMArrayType))

numba_dpex/core/runtime/_queuestruct.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
#ifndef NUMBA_DPEX_QUEUESTRUCT_H_
2-
#define NUMBA_DPEX_QUEUESTRUCT_H_
3-
/*
4-
* Fill in the *queuestruct* with information from the Numpy array *obj*.
5-
* *queuestruct*'s layout is defined in numba.targets.arrayobj (look
6-
* for the ArrayTemplate class).
7-
*/
1+
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// Defines the numba-dpex native representation for a dpctl.SyclQueue
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#pragma once
813

914
#include <Python.h>
1015

@@ -13,5 +18,3 @@ typedef struct
1318
PyObject *parent;
1419
void *queue_ref;
1520
} queuestruct_t;
16-
17-
#endif /* NUMBA_DPEX_QUEUESTRUCT_H_ */
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// Defines the numba-dpex native representation for a dpctl.tensor.usm_ndarray
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#pragma once
13+
14+
#include <Python.h>
15+
#include <numpy/npy_common.h>
16+
17+
typedef struct
18+
{
19+
void *meminfo;
20+
PyObject *parent;
21+
npy_intp nitems;
22+
npy_intp itemsize;
23+
void *data;
24+
void *sycl_queue;
25+
26+
npy_intp shape_and_strides[];
27+
} usmarystruct_t;

0 commit comments

Comments
 (0)