Skip to content

Commit 9056d1f

Browse files
committed
BUG: avoid segfault on bad arguments in ndarray.__array_function__
1 parent 5c8f540 commit 9056d1f

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

numpy/_core/src/multiarray/methods.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,14 @@ array_function(PyArrayObject *NPY_UNUSED(self), PyObject *c_args, PyObject *c_kw
11201120
&func, &types, &args, &kwargs)) {
11211121
return NULL;
11221122
}
1123-
1123+
if (!PyTuple_CheckExact(args)) {
1124+
PyErr_SetString(PyExc_TypeError, "args must be a tuple.");
1125+
return NULL;
1126+
}
1127+
if (!PyDict_CheckExact(kwargs)) {
1128+
PyErr_SetString(PyExc_TypeError, "kwargs must be a dict.");
1129+
return NULL;
1130+
}
11241131
types = PySequence_Fast(
11251132
types,
11261133
"types argument to ndarray.__array_function__ must be iterable");

numpy/_core/tests/test_overrides.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ def test_no_wrapper(self):
203203
array.__array_function__(func=func, types=(np.ndarray,),
204204
args=(array,), kwargs={})
205205

206+
def test_wrong_arguments(self):
207+
# Check our implementation guards against wrong arguments.
208+
a = np.array([1, 2])
209+
with pytest.raises(TypeError, match="args must be a tuple"):
210+
a.__array_function__(np.reshape, (np.ndarray,), a, (2, 1))
211+
with pytest.raises(TypeError, match="kwargs must be a dict"):
212+
a.__array_function__(np.reshape, (np.ndarray,), (a,), (2, 1))
213+
206214

207215
class TestArrayFunctionDispatch:
208216

0 commit comments

Comments
 (0)