Skip to content

Commit af314f5

Browse files
authored
Merge pull request #1188 from IntelPython/fix/sycl_event_meminfo
Fix lifetime management for sycl event
2 parents 067cbbb + d4d1ff0 commit af314f5

File tree

7 files changed

+53
-16
lines changed

7 files changed

+53
-16
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ class SyclEventModel(StructModel):
172172
def __init__(self, dmm, fe_type):
173173
members = [
174174
(
175-
"parent",
176-
types.CPointer(types.int8),
175+
"meminfo",
176+
types.MemInfoPointer(types.pyobject),
177177
),
178178
(
179179
"event_ref",

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "numba/core/runtime/nrt_external.h"
28+
2729
// forward declarations
2830
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
2931
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
@@ -64,9 +66,12 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
6466
PyArray_Descr *descr);
6567
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
6668
queuestruct_t *queue_struct);
67-
static int DPEXRT_sycl_event_from_python(PyObject *obj,
69+
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
70+
PyObject *obj,
6871
eventstruct_t *event_struct);
6972
static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct);
73+
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
74+
eventstruct_t *eventstruct);
7075

7176
/** An NRT_external_malloc_func implementation using DPCTLmalloc_device.
7277
*
@@ -1306,7 +1311,8 @@ static PyObject *DPEXRT_sycl_queue_to_python(queuestruct_t *queuestruct)
13061311
* represent a dpctl.SyclEvent inside Numba.
13071312
* @return {return} Return code indicating success (0) or failure (-1).
13081313
*/
1309-
static int DPEXRT_sycl_event_from_python(PyObject *obj,
1314+
static int DPEXRT_sycl_event_from_python(NRT_api_functions *nrt,
1315+
PyObject *obj,
13101316
eventstruct_t *event_struct)
13111317
{
13121318
struct PySyclEventObject *event_obj = NULL;
@@ -1328,7 +1334,13 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
13281334
goto error;
13291335
}
13301336

1331-
event_struct->parent = obj;
1337+
// We are doing incref here to ensure python does not release the object
1338+
// while NRT references it. Coresponding decref is called by NRT in
1339+
// NRT_MemInfo_pyobject_dtor once there is no reference to this object by
1340+
// the code managed by NRT.
1341+
Py_INCREF(event_obj);
1342+
event_struct->meminfo =
1343+
nrt->manage_memory(event_obj, NRT_MemInfo_pyobject_dtor);
13321344
event_struct->event_ref = event_ref;
13331345

13341346
return 0;
@@ -1355,12 +1367,13 @@ static int DPEXRT_sycl_event_from_python(PyObject *obj,
13551367
* @return {return} A PyObject created from the eventstruct->parent, if
13561368
* the PyObject could not be created return NULL.
13571369
*/
1358-
static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
1370+
static PyObject *DPEXRT_sycl_event_to_python(NRT_api_functions *nrt,
1371+
eventstruct_t *eventstruct)
13591372
{
13601373
PyObject *orig_event = NULL;
13611374
PyGILState_STATE gstate;
13621375

1363-
orig_event = eventstruct->parent;
1376+
orig_event = nrt->get_data(eventstruct->meminfo);
13641377
// FIXME: Better error checking is needed to enforce the boxing of the event
13651378
// object. For now, only the minimal is done as the returning of SyclEvent
13661379
// from a dpjit function should not be a used often and the dpctl C API for
@@ -1375,9 +1388,13 @@ static PyObject *DPEXRT_sycl_event_to_python(eventstruct_t *eventstruct)
13751388
DPEXRT_DEBUG(
13761389
drt_debug_print("DPEXRT-DEBUG: In DPEXRT_sycl_event_to_python.\n"););
13771390

1391+
// TODO: is there any way to release meminfo without calling dtor so we dont
1392+
// call incref, decref one after another.
13781393
// We need to increase reference count because we are returning new
13791394
// reference to the same event.
13801395
Py_INCREF(orig_event);
1396+
// We need to release meminfo since we are taking ownership back.
1397+
nrt->release(eventstruct->meminfo);
13811398

13821399
return orig_event;
13831400
}

numba_dpex/core/runtime/_eventstruct.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
#pragma once
1313

14-
#include <Python.h>
14+
#include "numba/core/runtime/nrt_external.h"
1515

1616
typedef struct
1717
{
18-
PyObject *parent;
18+
NRT_MemInfo *meminfo;
1919
void *event_ref;
2020
} eventstruct_t;

numba_dpex/core/runtime/_nrt_helper.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,15 @@ void NRT_MemInfo_destroy(NRT_MemInfo *mi)
134134
TheMSys.stats.mi_free++;
135135
}
136136
}
137+
138+
void NRT_MemInfo_pyobject_dtor(void *data)
139+
{
140+
PyGILState_STATE gstate;
141+
PyObject *ownerobj = data;
142+
143+
gstate = PyGILState_Ensure(); /* ensure the GIL */
144+
Py_DECREF(data); /* release the python object */
145+
PyGILState_Release(gstate); /* release the GIL */
146+
147+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: pyobject destructor\n"););
148+
}

numba_dpex/core/runtime/_nrt_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ size_t NRT_MemInfo_refcount(NRT_MemInfo *mi);
1919
void NRT_Free(void *ptr);
2020
void NRT_dealloc(NRT_MemInfo *mi);
2121
void NRT_MemInfo_destroy(NRT_MemInfo *mi);
22+
void NRT_MemInfo_pyobject_dtor(void *data);
2223

2324
#endif /* _NRT_HELPER_H_ */

numba_dpex/core/runtime/context.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import functools
66

7+
import numba.core.unsafe.nrt
78
from llvmlite import ir as llvmir
89
from numba.core import cgutils, types
910

@@ -206,26 +207,32 @@ def queuestruct_to_python(self, pyapi, val):
206207
def eventstruct_from_python(self, pyapi, obj, ptr):
207208
"""Calls the c function DPEXRT_sycl_event_from_python"""
208209
fnty = llvmir.FunctionType(
209-
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
210+
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
210211
)
212+
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
211213

212214
fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_from_python")
213215
fn.args[0].add_attribute("nocapture")
214216
fn.args[1].add_attribute("nocapture")
217+
fn.args[2].add_attribute("nocapture")
215218

216-
self.error = pyapi.builder.call(fn, (obj, ptr))
219+
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))
217220
return self.error
218221

219222
def eventstruct_to_python(self, pyapi, val):
220223
"""Calls the c function DPEXRT_sycl_event_to_python"""
221224

222-
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr])
225+
fnty = llvmir.FunctionType(pyapi.pyobj, [pyapi.voidptr, pyapi.voidptr])
226+
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
223227

224228
fn = pyapi._get_function(fnty, "DPEXRT_sycl_event_to_python")
225229
fn.args[0].add_attribute("nocapture")
230+
fn.args[1].add_attribute("nocapture")
231+
226232
qptr = cgutils.alloca_once_value(pyapi.builder, val)
227233
ptr = pyapi.builder.bitcast(qptr, pyapi.voidptr)
228-
self.error = pyapi.builder.call(fn, [ptr])
234+
235+
self.error = pyapi.builder.call(fn, [nrt_api, ptr])
229236

230237
return self.error
231238

numba_dpex/dpctl_iface/libsyclinterface_bindings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def dpctl_event_wait(builder: llvmir.IRBuilder, *args):
7171
mod = builder.module
7272
fn = _build_dpctl_function(
7373
llvm_module=mod,
74-
return_ty=cgutils.voidptr_t,
74+
return_ty=llvmir.VoidType(),
7575
arg_list=[cgutils.voidptr_t],
7676
func_name="DPCTLEvent_Wait",
7777
)
@@ -85,7 +85,7 @@ def dpctl_event_delete(builder: llvmir.IRBuilder, *args):
8585
mod = builder.module
8686
fn = _build_dpctl_function(
8787
llvm_module=mod,
88-
return_ty=cgutils.voidptr_t,
88+
return_ty=llvmir.VoidType(),
8989
arg_list=[cgutils.voidptr_t],
9090
func_name="DPCTLEvent_Delete",
9191
)
@@ -99,7 +99,7 @@ def dpctl_queue_delete(builder: llvmir.IRBuilder, *args):
9999
mod = builder.module
100100
fn = _build_dpctl_function(
101101
llvm_module=mod,
102-
return_ty=cgutils.voidptr_t,
102+
return_ty=llvmir.VoidType(),
103103
arg_list=[cgutils.voidptr_t],
104104
func_name="DPCTLQueue_Delete",
105105
)

0 commit comments

Comments
 (0)