Skip to content

Commit a0a7fc1

Browse files
committed
BUG: Ensure that scalar binops prioritize __array_ufunc__
If array-ufunc is implemented, we must call always use it for all operators (that seems to be the promise). If __array_function__ is defined we are in the clear w.r.t. recursion because the object is either an array (can be unpacked, but already checked earlier now also), or it cannot call the ufunc without unpacking itself (otherwise it would cause recursion). There is an oddity about `__array_wrap__`. Rather than trying to do odd things to deal with it, I added a comment explaining why it doens't matter (roughly: don't use our scalar priority if you want to be sure to get a chance).
1 parent e4a495d commit a0a7fc1

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,30 @@ find_binary_operation_path(
194194
PyLong_CheckExact(other) ||
195195
PyFloat_CheckExact(other) ||
196196
PyComplex_CheckExact(other) ||
197-
PyBool_Check(other)) {
197+
PyBool_Check(other) ||
198+
PyArray_Check(other)) {
198199
/*
199200
* The other operand is ready for the operation already. Must pass on
200201
* on float/long/complex mainly for weak promotion (NEP 50).
201202
*/
202-
Py_INCREF(other);
203-
*other_op = other;
203+
*other_op = Py_NewRef(other);
204204
return 0;
205205
}
206+
/*
207+
* If other has __array_ufunc__ always use ufunc. If array-ufunc was None
208+
* we already deferred. And any custom object with array-ufunc cannot call
209+
* our ufuncs without preventing recursion.
210+
* It may be nice to avoid double lookup in `BINOP_GIVE_UP_IF_NEEDED`.
211+
*/
212+
PyObject *attr = PyArray_LookupSpecial(other, npy_interned_str.array_ufunc);
213+
if (attr != NULL) {
214+
Py_DECREF(attr);
215+
*other_op = Py_NewRef(other);
216+
return 0;
217+
}
218+
else if (PyErr_Occurred()) {
219+
PyErr_Clear(); /* TODO[gh-14801]: propagate crashes during attribute access? */
220+
}
206221

207222
/*
208223
* Now check `other`. We want to know whether it is an object scalar
@@ -216,7 +231,13 @@ find_binary_operation_path(
216231
}
217232

218233
if (!was_scalar || PyArray_DESCR(arr)->type_num != NPY_OBJECT) {
219-
/* The array is OK for usage and we can simply forward it
234+
/*
235+
* The array is OK for usage and we can simply forward it. There
236+
* is a theoretical subtlety here: If the other object implements
237+
* `__array_wrap__`, we may ignore that. However, this only matters
238+
* if the other object has the identical `__array_priority__` and
239+
* additionally already deferred back to us.
240+
* (`obj + scalar` and `scalar + obj` are not symmetric.)
220241
*
221242
* NOTE: Future NumPy may need to distinguish scalars here, one option
222243
* could be marking the array.

numpy/_core/tests/test_multiarray.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4025,6 +4025,18 @@ class LowPriority(np.ndarray):
40254025
assert res.shape == (3,)
40264026
assert res[0] == 'result'
40274027

4028+
@pytest.mark.parametrize("scalar", [
4029+
np.longdouble(1), np.timedelta64(120, 'm')])
4030+
@pytest.mark.parametrize("op", [operator.add, operator.xor])
4031+
def test_scalar_binop_guarantees_ufunc(self, scalar, op):
4032+
# Test that __array_ufunc__ will always cause ufunc use even when
4033+
# we have to protect some other calls from recursing (see gh-26904).
4034+
class SomeClass:
4035+
def __array_ufunc__(self, ufunc, method, *inputs, **kw):
4036+
return "result"
4037+
4038+
assert SomeClass() + np.longdouble(1) == "result"
4039+
assert np.longdouble(1) + SomeClass() == "result"
40284040

40294041
def test_ufunc_override_normalize_signature(self):
40304042
# gh-5674

numpy/_core/tests/test_scalarmath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
floating_types = np.floating.__subclasses__()
2828
complex_floating_types = np.complexfloating.__subclasses__()
2929

30-
objecty_things = [object(), None]
30+
objecty_things = [object(), None, np.array(None, dtype=object)]
3131

3232
binary_operators_for_scalars = [
3333
operator.lt, operator.le, operator.eq, operator.ne, operator.ge,

0 commit comments

Comments
 (0)