@@ -744,14 +744,18 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
744
744
DPEXRT_DEBUG (
745
745
drt_debug_print ("DPEXRT-DEBUG: usm array was passed directly\n" ));
746
746
arrayobj = obj ;
747
+ Py_INCREF (arrayobj );
747
748
}
748
749
else if (PyObject_HasAttrString (obj , "_array_obj" )) {
750
+ // PyObject_GetAttrString gives reference
749
751
arrayobj = PyObject_GetAttrString (obj , "_array_obj" );
750
752
751
753
if (!arrayobj )
752
754
return NULL ;
753
- if (!PyObject_TypeCheck (arrayobj , & PyUSMArrayType ))
755
+ if (!PyObject_TypeCheck (arrayobj , & PyUSMArrayType )) {
756
+ Py_DECREF (arrayobj );
754
757
return NULL ;
758
+ }
755
759
}
756
760
757
761
struct PyUSMArrayObject * pyusmarrayobj =
@@ -803,17 +807,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
803
807
PyGILState_STATE gstate ;
804
808
npy_intp itemsize = 0 ;
805
809
806
- // Increment the ref count on obj to prevent CPython from garbage
807
- // collecting the array.
808
- // TODO: add extra description why do we need this
809
- Py_IncRef (obj );
810
-
811
810
DPEXRT_DEBUG (drt_debug_print (
812
811
"DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_from_python at %s, line %d\n" ,
813
812
__FILE__ , __LINE__ ));
814
813
815
814
// Check if the PyObject obj has an _array_obj attribute that is of
816
815
// dpctl.tensor.usm_ndarray type.
816
+ // arrayobj is a new reference, reference of obj is borrowed
817
817
if (!(arrayobj = PyUSMNdArray_ARRAYOBJ (obj ))) {
818
818
DPEXRT_DEBUG (drt_debug_print (
819
819
"DPEXRT-ERROR: PyUSMNdArray_ARRAYOBJ check failed at %s, line %d\n" ,
@@ -832,6 +832,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
832
832
data = (void * )UsmNDArray_GetData (arrayobj );
833
833
nitems = product_of_shape (shape , ndim );
834
834
itemsize = (npy_intp )UsmNDArray_GetElementSize (arrayobj );
835
+
835
836
if (!(qref = UsmNDArray_GetQueueRef (arrayobj ))) {
836
837
DPEXRT_DEBUG (drt_debug_print (
837
838
"DPEXRT-ERROR: UsmNDArray_GetQueueRef returned NULL at "
@@ -841,7 +842,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
841
842
}
842
843
843
844
if (!(arystruct -> meminfo = NRT_MemInfo_new_from_usmndarray (
844
- obj , data , nitems , itemsize , qref )))
845
+ arrayobj , data , nitems , itemsize , qref )))
845
846
{
846
847
DPEXRT_DEBUG (drt_debug_print (
847
848
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
@@ -854,7 +855,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
854
855
arystruct -> sycl_queue = qref ;
855
856
arystruct -> nitems = nitems ;
856
857
arystruct -> itemsize = itemsize ;
857
- arystruct -> parent = obj ;
858
+ arystruct -> parent = arrayobj ;
858
859
859
860
p = arystruct -> shape_and_strides ;
860
861
@@ -906,7 +907,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
906
907
__FILE__ , __LINE__ ));
907
908
gstate = PyGILState_Ensure ();
908
909
// decref the python object
909
- Py_DECREF ( obj );
910
+ Py_XDECREF ( arrayobj );
910
911
// release the GIL
911
912
PyGILState_Release (gstate );
912
913
0 commit comments