Skip to content

Commit 65c1d3b

Browse files
Fixed ref-counting of Python object temporaries in unboxing code
Ref-count of `arrayobj` was not decremented after use in case `dpnp.ndarray` object was being processed.
1 parent e36c979 commit 65c1d3b

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,18 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
744744
DPEXRT_DEBUG(
745745
drt_debug_print("DPEXRT-DEBUG: usm array was passed directly\n"));
746746
arrayobj = obj;
747+
Py_INCREF(arrayobj);
747748
}
748749
else if (PyObject_HasAttrString(obj, "_array_obj")) {
750+
// PyObject_GetAttrString gives reference
749751
arrayobj = PyObject_GetAttrString(obj, "_array_obj");
750752

751753
if (!arrayobj)
752754
return NULL;
753-
if (!PyObject_TypeCheck(arrayobj, &PyUSMArrayType))
755+
if (!PyObject_TypeCheck(arrayobj, &PyUSMArrayType)) {
756+
Py_DECREF(arrayobj);
754757
return NULL;
758+
}
755759
}
756760

757761
struct PyUSMArrayObject *pyusmarrayobj =
@@ -803,17 +807,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
803807
PyGILState_STATE gstate;
804808
npy_intp itemsize = 0;
805809

806-
// Increment the ref count on obj to prevent CPython from garbage
807-
// collecting the array.
808-
// TODO: add extra description why do we need this
809-
Py_IncRef(obj);
810-
811810
DPEXRT_DEBUG(drt_debug_print(
812811
"DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_from_python at %s, line %d\n",
813812
__FILE__, __LINE__));
814813

815814
// Check if the PyObject obj has an _array_obj attribute that is of
816815
// dpctl.tensor.usm_ndarray type.
816+
// arrayobj is a new reference, reference of obj is borrowed
817817
if (!(arrayobj = PyUSMNdArray_ARRAYOBJ(obj))) {
818818
DPEXRT_DEBUG(drt_debug_print(
819819
"DPEXRT-ERROR: PyUSMNdArray_ARRAYOBJ check failed at %s, line %d\n",
@@ -832,6 +832,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
832832
data = (void *)UsmNDArray_GetData(arrayobj);
833833
nitems = product_of_shape(shape, ndim);
834834
itemsize = (npy_intp)UsmNDArray_GetElementSize(arrayobj);
835+
835836
if (!(qref = UsmNDArray_GetQueueRef(arrayobj))) {
836837
DPEXRT_DEBUG(drt_debug_print(
837838
"DPEXRT-ERROR: UsmNDArray_GetQueueRef returned NULL at "
@@ -841,7 +842,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
841842
}
842843

843844
if (!(arystruct->meminfo = NRT_MemInfo_new_from_usmndarray(
844-
obj, data, nitems, itemsize, qref)))
845+
arrayobj, data, nitems, itemsize, qref)))
845846
{
846847
DPEXRT_DEBUG(drt_debug_print(
847848
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
@@ -854,7 +855,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
854855
arystruct->sycl_queue = qref;
855856
arystruct->nitems = nitems;
856857
arystruct->itemsize = itemsize;
857-
arystruct->parent = obj;
858+
arystruct->parent = arrayobj;
858859

859860
p = arystruct->shape_and_strides;
860861

@@ -906,7 +907,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
906907
__FILE__, __LINE__));
907908
gstate = PyGILState_Ensure();
908909
// decref the python object
909-
Py_DECREF(obj);
910+
Py_XDECREF(arrayobj);
910911
// release the GIL
911912
PyGILState_Release(gstate);
912913

0 commit comments

Comments
 (0)