Skip to content

Commit e41b9c1

Browse files
authored
BUG: PyDataMem_SetHandler check capsule name (numpy#26529)
Closes numpy#26137. Tests capsule name and sets PyErr if not valid: if (!PyCapsule_IsValid(handler, MEM_HANDLER_CAPSULE_NAME)) { PyErr_SetString(PyExc_ValueError, "Capsule must be named 'mem_handler'") return NULL; }
1 parent a03e0ef commit e41b9c1

File tree

5 files changed

+33
-7
lines changed

5 files changed

+33
-7
lines changed

numpy/_core/src/multiarray/alloc.c

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ NPY_NO_EXPORT void *
354354
PyDataMem_UserNEW(size_t size, PyObject *mem_handler)
355355
{
356356
void *result;
357-
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
357+
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
358+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
358359
if (handler == NULL) {
359360
return NULL;
360361
}
@@ -368,7 +369,8 @@ NPY_NO_EXPORT void *
368369
PyDataMem_UserNEW_ZEROED(size_t nmemb, size_t size, PyObject *mem_handler)
369370
{
370371
void *result;
371-
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
372+
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
373+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
372374
if (handler == NULL) {
373375
return NULL;
374376
}
@@ -381,7 +383,8 @@ PyDataMem_UserNEW_ZEROED(size_t nmemb, size_t size, PyObject *mem_handler)
381383
NPY_NO_EXPORT void
382384
PyDataMem_UserFREE(void *ptr, size_t size, PyObject *mem_handler)
383385
{
384-
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
386+
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
387+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
385388
if (handler == NULL) {
386389
WARN_NO_RETURN(PyExc_RuntimeWarning,
387390
"Could not get pointer to 'mem_handler' from PyCapsule");
@@ -395,7 +398,8 @@ NPY_NO_EXPORT void *
395398
PyDataMem_UserRENEW(void *ptr, size_t size, PyObject *mem_handler)
396399
{
397400
void *result;
398-
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
401+
PyDataMem_Handler *handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
402+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
399403
if (handler == NULL) {
400404
return NULL;
401405
}
@@ -427,6 +431,10 @@ PyDataMem_SetHandler(PyObject *handler)
427431
if (handler == NULL) {
428432
handler = PyDataMem_DefaultHandler;
429433
}
434+
if (!PyCapsule_IsValid(handler, MEM_HANDLER_CAPSULE_NAME)) {
435+
PyErr_SetString(PyExc_ValueError, "Capsule must be named 'mem_handler'");
436+
return NULL;
437+
}
430438
token = PyContextVar_Set(current_handler, handler);
431439
if (token == NULL) {
432440
Py_DECREF(old_handler);
@@ -477,7 +485,8 @@ get_handler_name(PyObject *NPY_UNUSED(self), PyObject *args)
477485
return NULL;
478486
}
479487
}
480-
handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
488+
handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
489+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
481490
if (handler == NULL) {
482491
Py_DECREF(mem_handler);
483492
return NULL;
@@ -514,7 +523,8 @@ get_handler_version(PyObject *NPY_UNUSED(self), PyObject *args)
514523
return NULL;
515524
}
516525
}
517-
handler = (PyDataMem_Handler *) PyCapsule_GetPointer(mem_handler, "mem_handler");
526+
handler = (PyDataMem_Handler *) PyCapsule_GetPointer(
527+
mem_handler, MEM_HANDLER_CAPSULE_NAME);
518528
if (handler == NULL) {
519529
Py_DECREF(mem_handler);
520530
return NULL;

numpy/_core/src/multiarray/alloc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "numpy/ndarraytypes.h"
77

88
#define NPY_TRACE_DOMAIN 389047
9+
#define MEM_HANDLER_CAPSULE_NAME "mem_handler"
910

1011
NPY_NO_EXPORT PyObject *
1112
_get_madvise_hugepage(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args));

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5230,7 +5230,8 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
52305230
/*
52315231
* Initialize the default PyDataMem_Handler capsule singleton.
52325232
*/
5233-
PyDataMem_DefaultHandler = PyCapsule_New(&default_handler, "mem_handler", NULL);
5233+
PyDataMem_DefaultHandler = PyCapsule_New(
5234+
&default_handler, MEM_HANDLER_CAPSULE_NAME, NULL);
52345235
if (PyDataMem_DefaultHandler == NULL) {
52355236
goto err;
52365237
}

numpy/_core/tests/__init__.py

Whitespace-only changes.

numpy/_core/tests/test_mem_policy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def get_module(tmp_path):
4646
Py_DECREF(secret_data);
4747
return old;
4848
"""),
49+
("set_wrong_capsule_name_data_policy", "METH_NOARGS", """
50+
PyObject *wrong_name_capsule =
51+
PyCapsule_New(&secret_data_handler, "not_mem_handler", NULL);
52+
if (wrong_name_capsule == NULL) {
53+
return NULL;
54+
}
55+
PyObject *old = PyDataMem_SetHandler(wrong_name_capsule);
56+
Py_DECREF(wrong_name_capsule);
57+
return old;
58+
"""),
4959
("set_old_policy", "METH_O", """
5060
PyObject *old;
5161
if (args != NULL && PyCapsule_CheckExact(args)) {
@@ -252,6 +262,10 @@ def test_set_policy(get_module):
252262
get_module.set_old_policy(orig_policy)
253263
assert get_handler_name() == orig_policy_name
254264

265+
with pytest.raises(ValueError,
266+
match="Capsule must be named 'mem_handler'"):
267+
get_module.set_wrong_capsule_name_data_policy()
268+
255269

256270
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="no numpy.distutils")
257271
def test_default_policy_singleton(get_module):

0 commit comments

Comments
 (0)