Skip to content

Commit a030c85

Browse files
author
Diptorup Deb
authored
Merge pull request #1190 from IntelPython/fix/sycl_queue_meminfo
Fix lifetime management for sycl queue
2 parents af314f5 + 8df9ea1 commit a030c85

File tree

4 files changed

+34
-14
lines changed

4 files changed

+34
-14
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ class SyclQueueModel(StructModel):
145145
def __init__(self, dmm, fe_type):
146146
members = [
147147
(
148-
"parent",
149-
types.CPointer(types.int8),
148+
"meminfo",
149+
types.MemInfoPointer(types.pyobject),
150150
),
151151
(
152152
"queue_ref",

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,14 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
6464
int ndim,
6565
int writeable,
6666
PyArray_Descr *descr);
67-
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
67+
static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt,
68+
PyObject *obj,
6869
queuestruct_t *queue_struct);
6970
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
7071
PyObject *obj,
7172
eventstruct_t *event_struct);
72-
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct);
73+
static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt,
74+
queuestruct_t *queuestruct);
7375
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
7476
eventstruct_t *eventstruct);
7577

@@ -1216,7 +1218,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
12161218
* represent a dpctl.SyclQueue inside Numba.
12171219
* @return {return} Return code indicating success (0) or failure (-1).
12181220
*/
1219-
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
1221+
static int DPEXRT_sycl_queue_from_python(NRT_api_functions *nrt,
1222+
PyObject *obj,
12201223
queuestruct_t *queue_struct)
12211224
{
12221225
struct PySyclQueueObject *queue_obj = NULL;
@@ -1246,7 +1249,13 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
12461249
DPCTLDeviceMgr_GetDeviceInfoStr(device_ref));
12471250
DPCTLDevice_Delete(device_ref););
12481251

1249-
queue_struct->parent = obj;
1252+
// We are doing incref here to ensure python does not release the object
1253+
// while NRT references it. Corresponding decref is called by NRT in
1254+
// NRT_MemInfo_pyobject_dtor once there is no reference to this object by
1255+
// the code managed by NRT.
1256+
Py_INCREF(queue_obj);
1257+
queue_struct->meminfo =
1258+
nrt->manage_memory(queue_obj, NRT_MemInfo_pyobject_dtor);
12501259
queue_struct->queue_ref = queue_ref;
12511260

12521261
return 0;
@@ -1275,11 +1284,12 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
12751284
* @return {return} A PyObject created from the queuestruct->meminfo, if
12761285
* the PyObject could not be created return NULL.
12771286
*/
1278-
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
1287+
static PyObject *DPEXRT_sycl_queue_to_python(NRT_api_functions *nrt,
1288+
queuestruct_t *queuestruct)
12791289
{
12801290
PyObject *orig_queue = NULL;
12811291

1282-
orig_queue = queuestruct->parent;
1292+
orig_queue = nrt->get_data(queuestruct->meminfo);
12831293
// FIXME: Better error checking is needed to enforce the boxing of the queue
12841294
// object. For now, only the minimal is done as the returning of SyclQueue
12851295
// from a dpjit function should not be used often and the dpctl C API for
@@ -1291,9 +1301,13 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
12911301
return NULL;
12921302
}
12931303

1304+
// TODO: is there any way to release meminfo without calling dtor so we don't
1305+
// call incref, decref one after another.
12941306
// We need to increase reference count because we are returning new
12951307
// reference to the same queue.
12961308
Py_INCREF(orig_queue);
1309+
// We need to release meminfo since we are taking ownership back.
1310+
nrt->release(queuestruct->meminfo);
12971311

12981312
return orig_queue;
12991313
}

numba_dpex/core/runtime/_queuestruct.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
#pragma once
1313

14-
#include <Python.h>
14+
#include "numba/core/runtime/nrt_external.h"
1515

1616
typedef struct
1717
{
18-
PyObject *parent;
18+
NRT_MemInfo *meminfo;
1919
void *queue_ref;
2020
} queuestruct_t;

numba_dpex/core/runtime/context.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,26 +181,32 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
181181
def queuestruct_from_python(self, pyapi, obj, ptr):
182182
"""Calls the c function DPEXRT_sycl_queue_from_python"""
183183
fnty = llvmir.FunctionType(
184-
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
184+
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
185185
)
186+
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
186187

187188
fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_from_python")
188189
fn.args[0].add_attribute("nocapture")
189190
fn.args[1].add_attribute("nocapture")
191+
fn.args[2].add_attribute("nocapture")
190192

191-
self.error = pyapi.builder.call(fn, (obj, ptr))
193+
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
192194
return self.error
193195

194196
def queuestruct_to_python(self, pyapi, val):
195197
"""Calls the c function DPEXRT_sycl_queue_to_python"""
196198

197-
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr])
199+
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr])
200+
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
198201

199202
fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_to_python")
200203
fn.args[0].add_attribute("nocapture")
204+
fn.args[1].add_attribute("nocapture")
205+
201206
qptr = cgutils.alloca_once_value(pyapi.builder, val)
202207
ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr)
203-
self.error = pyapi.builder.call(fn, [ptr])
208+
209+
self.error = pyapi.builder.call(fn, [nrt_api, ptr])
204210

205211
return self.error
206212

0 commit comments

Comments
 (0)