dlpack destructor "clobbers preexisting error indicator" #31271

@kcdodd

It honestly took some time to decide whether this is even jax's responsibility, but some recent additions to the CPython documentation (not yet on docs.python.org, but copied below) do state that destructors need to work in the presence of as-yet-unhandled errors (more alternatives are also at Python Specification for DLPack). The dlpack capsule destructor clears the error set by PyCapsule_GetPointer if the capsule name has changed, but this also clears any error that was already set before the destructor ran.

I would love to give a minimal reproducer here of what happens, but the only time I can get an observable error is via C extension module functions, and even there the error message depends on when exactly the capsule gets deallocated (as a "local" variable, or as an "argument"). The interpreter seems to guard against this in regular Python functions and all builtins that I have checked. But here is a Cython-based reproducer:

# jax_dlpack.py
import traceback
from jax import numpy as jnp
from dummy_cython_mod import dummy_cfunc1, dummy_cfunc2

try:
  x = jnp.zeros(1)
  dummy_cfunc1(x.__dlpack__())
except Exception:
  traceback.print_exc()

try:
  x = jnp.zeros(1)
  dummy_cfunc2([x.__dlpack__()])
except Exception:
  traceback.print_exc()
Traceback (most recent call last):
  File "jax_dlpack.py", line 7, in <module>
    dummy_cfunc1(x.__dlpack__())
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^
SystemError: error return without exception set
Traceback (most recent call last):
  File "jax_dlpack.py", line 13, in <module>
    dummy_cfunc2([x.__dlpack__()])
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
SystemError: <cyfunction dummy_cfunc2 at 0x7a22f7cbc400> returned NULL without setting an exception
# dummy_cython_mod.pyx
from cpython cimport PyCapsule_SetName

def dummy_cfunc1(pack):
  # ignore the fact that this would already be a memory leak even without any other errors.
  PyCapsule_SetName(pack, b'used_dltensor')
  assert False

def dummy_cfunc2(packs):
  pack = packs.pop()
  PyCapsule_SetName(pack, b'used_dltensor')
  assert False

The following is for reference:

// jax/jaxlib/dlpack.cc:320
                      DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                          PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                      if (dlmt) {
                        DLPackTensorDeleter(dlmt);
                      } else {
                        // The tensor has been deleted. Clear any error from
                        // PyCapsule_GetPointer.
                        PyErr_Clear();
                      }
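
To illustrate why the `else` branch above is reached once a consumer renames the capsule, here is a rough sketch that uses only `ctypes.pythonapi` and no jax. It is not a reproducer of the clobbering itself; it only shows that `PyCapsule_GetPointer` sets a `ValueError` when asked for a name the capsule no longer has. The payload, the `get_pointer_error` helper, and the capsule built here are hypothetical stand-ins for the real DLPack capsule.

```python
import ctypes

# ctypes.pythonapi is a PyDLL: it holds the GIL and re-raises any error
# that the C call leaves in the error indicator.
api = ctypes.pythonapi

api.PyCapsule_New.restype = ctypes.py_object
api.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
api.PyCapsule_GetPointer.restype = ctypes.c_void_p
api.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
api.PyCapsule_SetName.restype = ctypes.c_int
api.PyCapsule_SetName.argtypes = [ctypes.py_object, ctypes.c_char_p]

# The capsule stores the name pointer without copying it, so the bytes
# objects must stay alive for as long as the capsule does.
NAME = b"dltensor"
USED = b"used_dltensor"

# Hypothetical payload standing in for a DLManagedTensor.
payload = ctypes.c_int(42)


def get_pointer_error(capsule, name):
    """Return the exception PyCapsule_GetPointer sets for `name`, or None."""
    try:
        api.PyCapsule_GetPointer(capsule, name)
    except ValueError as exc:
        return exc
    return None


# No destructor is attached (None becomes NULL), so nothing runs on dealloc.
capsule = api.PyCapsule_New(ctypes.addressof(payload), NAME, None)

before = get_pointer_error(capsule, NAME)  # name still matches
api.PyCapsule_SetName(capsule, USED)       # what a DLPack consumer does
after = get_pointer_error(capsule, NAME)   # name no longer matches

print("before rename:", before)
print("after rename: ", repr(after))
```

In the destructor quoted above, that same lookup failure is what the `PyErr_Clear()` call is meant to discard; the concern in this issue is that the call cannot tell it apart from an error that was already pending when deallocation started.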

python/cpython#28358

cpython/Doc/c-api/typeobj.rst:688

   If you may call functions that may set the error indicator, you must use
   :c:func:`PyErr_GetRaisedException` and :c:func:`PyErr_SetRaisedException`
   to ensure you don't clobber a preexisting error indicator (the deallocation
   could have occurred while processing a different error):

   .. code-block:: c

     static void
     foo_dealloc(foo_object *self)
     {
         PyObject *et, *ev, *etb;
         PyObject *exc = PyErr_GetRaisedException();
         ...
         PyErr_SetRaisedException(exc);
     }
