Skip to content

Commit 0364e76

Browse files
committed
Use nrt api to allocate meminfo object
1 parent 574daab commit 0364e76

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,23 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
4545
bool value_is_float,
4646
int64_t value,
4747
const DPCTLSyclQueueRef qref);
48-
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
48+
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
49+
PyObject *ndarrobj,
4950
void *data,
5051
npy_intp nitems,
5152
npy_intp itemsize,
5253
DPCTLSyclQueueRef qref);
53-
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
54+
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
55+
npy_intp size,
5456
size_t usm_type,
5557
const DPCTLSyclQueueRef qref);
5658
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
5759
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
5860
int ndim,
5961
PyArray_Descr *descr);
6062

61-
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
63+
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
64+
PyObject *obj,
6265
usmarystruct_t *arystruct);
6366
static PyObject *
6467
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
@@ -336,6 +339,11 @@ NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type)
336339
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info)
337340
{
338341
MemInfoDtorInfo *mi_dtor_info = NULL;
342+
// Warning: we are destructing sycl memory. MI destructor is called
343+
// separately by numba.
344+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Call to "
345+
"usmndarray_meminfo_dtor at %s, line %d\n",
346+
__FILE__, __LINE__));
339347

340348
// Sanity-check to make sure the mi_dtor_info is an actual pointer.
341349
if (!(mi_dtor_info = (MemInfoDtorInfo *)info)) {
@@ -416,7 +424,8 @@ static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner)
416424
* of the dpnp.ndarray was allocated.
417425
* @return {return} A new NRT_MemInfo object
418426
*/
419-
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
427+
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
428+
PyObject *ndarrobj,
420429
void *data,
421430
npy_intp nitems,
422431
npy_intp itemsize,
@@ -427,8 +436,9 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
427436
MemInfoDtorInfo *midtor_info = NULL;
428437
DPCTLSyclContextRef cref = NULL;
429438

430-
// Allocate a new NRT_MemInfo object
431-
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
439+
// Allocate a new NRT_MemInfo object. By passing 0 we are just allocating
440+
// MemInfo and not the `data` that the MemInfo object manages.
441+
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
432442
DPEXRT_DEBUG(drt_debug_print(
433443
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo "
434444
"object at %s, line %d\n",
@@ -505,7 +515,8 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
505515
* @return {return} A new NRT_MemInfo object, NULL if no NRT_MemInfo
506516
* object could be created.
507517
*/
508-
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
518+
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
519+
npy_intp size,
509520
size_t usm_type,
510521
const DPCTLSyclQueueRef qref)
511522
{
@@ -517,7 +528,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
517528
"DPEXRT-DEBUG: Inside DPEXRT_MemInfo_alloc %s, line %d\n", __FILE__,
518529
__LINE__));
519530
// Allocate a new NRT_MemInfo object
520-
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
531+
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
521532
DPEXRT_DEBUG(drt_debug_print(
522533
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo object.\n"));
523534
goto error;
@@ -795,7 +806,8 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
795806
* instance of a dpnp.ndarray
796807
* @return {return} Error code representing success (0) or failure (-1).
797808
*/
798-
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
809+
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
810+
PyObject *obj,
799811
usmarystruct_t *arystruct)
800812
{
801813
struct PyUSMArrayObject *arrayobj = NULL;
@@ -842,7 +854,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
842854
}
843855

844856
if (!(arystruct->meminfo = NRT_MemInfo_new_from_usmndarray(
845-
obj, data, nitems, itemsize, qref)))
857+
nrt, obj, data, nitems, itemsize, qref)))
846858
{
847859
DPEXRT_DEBUG(drt_debug_print(
848860
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "

numba_dpex/core/runtime/context.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, queue_ref):
8989
mod = builder.module
9090
u64 = llvmir.IntType(64)
9191
fnty = llvmir.FunctionType(
92-
cgutils.voidptr_t, [cgutils.intp_t, u64, cgutils.voidptr_t]
92+
cgutils.voidptr_t,
93+
[cgutils.voidptr_t, cgutils.intp_t, u64, cgutils.voidptr_t],
9394
)
9495
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc")
9596
fn.return_value.add_attribute("noalias")
97+
nrt_api = self._context.nrt.get_nrt_api(builder)
9698

97-
ret = builder.call(fn, [size, usm_type, queue_ref])
99+
ret = builder.call(fn, [nrt_api, size, usm_type, queue_ref])
98100

99101
return ret
100102

@@ -168,13 +170,15 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
168170
169171
"""
170172
fnty = llvmir.FunctionType(
171-
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
173+
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
172174
)
175+
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
173176
fn = pyapi._get_function(fnty, "DPEXRT_sycl_usm_ndarray_from_python")
174177
fn.args[0].add_attribute("nocapture")
175178
fn.args[1].add_attribute("nocapture")
179+
fn.args[2].add_attribute("nocapture")
176180

177-
self.error = pyapi.builder.call(fn, (obj, ptr))
181+
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
178182

179183
return self.error
180184

0 commit comments

Comments
 (0)